From 575c811ca45d6609ef2bbc31427097a976dd1801 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:07:02 +0100 Subject: [PATCH 01/11] plugins(lsps2): add session FSM for LSPS2 payment collection Introduce a state-machine-based approach to managing LSPS2 JIT channel sessions. The FSM tracks payment collection from initial channel open through HTLC forwarding to completion, replacing the previous ad-hoc state tracking. Also adds Sum trait for Msat and PartialEq for protocol types needed by the FSM. Changelog-Experimental: LSPS2 session state machine for JIT channels --- plugins/lsps-plugin/Cargo.toml | 2 +- plugins/lsps-plugin/src/core/lsps2/session.rs | 1908 +++++++++++++++++ plugins/lsps-plugin/src/proto/lsps0.rs | 7 + plugins/lsps-plugin/src/proto/lsps2.rs | 40 +- 4 files changed, 1949 insertions(+), 8 deletions(-) create mode 100644 plugins/lsps-plugin/src/core/lsps2/session.rs diff --git a/plugins/lsps-plugin/Cargo.toml b/plugins/lsps-plugin/Cargo.toml index 7fd074f1c48b..d1b99ada32ee 100644 --- a/plugins/lsps-plugin/Cargo.toml +++ b/plugins/lsps-plugin/Cargo.toml @@ -25,4 +25,4 @@ rand = "0.9" serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0", features = ["raw_value"] } thiserror = "2.0" -tokio = { version = "1.44", features = ["full"] } +tokio = { version = "1.44", features = ["full", "test-util"] } diff --git a/plugins/lsps-plugin/src/core/lsps2/session.rs b/plugins/lsps-plugin/src/core/lsps2/session.rs new file mode 100644 index 000000000000..298448548a8e --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/session.rs @@ -0,0 +1,1908 @@ +//! Lsps2 Service FSM + +use crate::proto::{ + lsps0::Msat, + lsps2::{ + compute_opening_fee, + failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, + OpeningFeeParams, SessionOutcome, + }, +}; + +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum Error { + #[error("variable amount payments are not supported")] + UnimplementedVarAmount, + #[error("opening fee computation overflow")] + FeeOverflow, + #[error("invalid state transition")] + InvalidTransition { + state: SessionState, + input: SessionInput, + }, + #[error( + "opening fee {opening_fee_msat} exceeds deductible capacity {deductible_capacity_msat}" + )] + InsufficientDeductibleCapacity { + opening_fee_msat: u64, + deductible_capacity_msat: u128, + }, +} + +type Result = std::result::Result; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PaymentPart { + pub htlc_id: u64, + pub amount_msat: Msat, + pub cltv_expiry: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ForwardPart { + pub htlc_id: u64, + pub fee_msat: u64, + pub forward_msat: u64, +} + +impl From for ForwardPart { + fn from(part: PaymentPart) -> Self { + Self { + htlc_id: part.htlc_id, + fee_msat: 0, + forward_msat: part.amount_msat.msat(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionInput { + /// Htlc intercepted + AddPart { part: PaymentPart }, + /// Timeout waiting for parts to arrive from blip052: defaults to 90s. + CollectTimeout, + /// Channel funding failed. + FundingFailed, + /// Zero-conf channel funded, withheld, and ready. + ChannelReady { + channel_id: String, + funding_psbt: String, + }, + /// The initial payment was successfull + PaymentSettled, + /// The inital payment failed + PaymentFailed, + /// Funding tx was broadcasted + FundingBroadcasted, + /// A new block has been mined. + NewBlock { height: u32 }, + /// The JIT channel has been closed or is no longer in CHANNELD_NORMAL. + ChannelClosed { channel_id: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionAction { + FailHtlcs { + failure_code: &'static str, + }, + ForwardHtlcs { + parts: Vec, + channel_id: String, + }, + FundChannel { + peer_id: String, + channel_capacity_msat: Msat, + opening_fee_params: OpeningFeeParams, + }, + FailSession, + AbandonSession { + channel_id: String, + funding_psbt: String, + }, + BroadcastFundingTx { + channel_id: String, + funding_psbt: String, + }, + Disconnect, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionEvent { + PaymentPartAdded { + part: PaymentPart, + n_parts: usize, + parts_sum: Msat, + }, + TooManyParts { + n_parts: usize, + }, + PaymentInsufficientForOpeningFee { + opening_fee_msat: u64, + n_parts: usize, + parts_sum: Msat, + }, + CollectTimeout { + n_parts: usize, + parts_sum: Msat, + }, + FundingChannel, + ForwardHtlcs { + channel_id: String, + n_parts: usize, + parts_sum: Msat, + opening_fee_msat: u64, + }, + PaymentSettled { + parts: Vec, + }, + PaymentFailed, + ChannelReady { + channel_id: String, + funding_psbt: String, + }, + FundingBroadcasted { + funding_psbt: String, + }, + SessionFailed, + SessionAbandoned, + SessionSucceeded, + UnsafeHtlcTimeout { + height: u32, + cltv_min: u32, + }, + UnusualInput { + state: String, + input: String, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionState { + Collecting { + parts: Vec, + }, + + /// Channel opened in progress, waiting for `channel_ready`. + AwaitingChannelReady { + parts: Vec, + opening_fee_msat: u64, + }, + + /// HTLCs forwarded, waiting for the client to settle or reject. + AwaitingSettlement { + forwarded_parts: Vec, + forwarded_amount_msat: u64, + deducted_fee_msat: u64, + channel_id: String, + funding_psbt: String, + }, + + /// HTLCs got resolved, broadcasting funding tx. + Broadcasting { + channel_id: String, + funding_psbt: String, + }, + + /// Terminal: session failed before a channel was opened. + Failed, + + /// Terminal: session failed after a channel was opened. + Abandoned, + + /// Terminal: session successfully finished + Succeeded, +} + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct ApplyResult { + pub actions: Vec, + pub events: Vec, +} + +impl ApplyResult { + fn unusual_input(state: &SessionState, input: &SessionInput) -> Self { + Self { + events: vec![SessionEvent::UnusualInput { + state: format!("{:?}", state), + input: format!("{:?}", input), + }], + ..Default::default() + } + } +} + +fn cltv_min(parts: &[PaymentPart]) -> Option { + parts.iter().map(|p| p.cltv_expiry).min() +} + +#[derive(Debug)] +pub struct Session { + state: SessionState, + // from BOLT2 + max_parts: usize, + // From the offer/fee_policy + opening_fee_params: OpeningFeeParams, + payment_size_msat: Option, + channel_capacity_msat: Msat, + peer_id: String, +} + +impl Session { + pub fn new( + max_parts: usize, + opening_fee_params: OpeningFeeParams, + payment_size_msat: Option, + channel_capacity_msat: Msat, + peer_id: String, + ) -> Self { + Self { + state: SessionState::Collecting { parts: vec![] }, + max_parts, + opening_fee_params, + payment_size_msat, + channel_capacity_msat, + peer_id, + } + } + + pub fn is_terminal(&self) -> bool { + matches!( + self.state, + SessionState::Failed | SessionState::Abandoned | SessionState::Succeeded + ) + } + + pub fn outcome(&self) -> Option { + match &self.state { + SessionState::Succeeded => Some(SessionOutcome::Succeeded), + SessionState::Abandoned => Some(SessionOutcome::Abandoned), + SessionState::Failed => Some(SessionOutcome::Failed), + _ => None, + } + } + + pub fn apply(&mut self, input: SessionInput) -> Result { + match (&mut self.state, input) { + // + // Collecting transitions. + // + (SessionState::Collecting { parts }, SessionInput::AddPart { part }) => { + if self.payment_size_msat.is_none() { + return Err(Error::UnimplementedVarAmount); + } + + parts.push(part.clone()); + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + let mut events = vec![SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }]; + + // Fail early if we have too many parts. + if n_parts > self.max_parts { + self.state = SessionState::Failed; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + + let expected_msat = self.payment_size_msat.unwrap_or_else(|| Msat(0)); // We checked that it isn't None + if parts_sum >= expected_msat { + let opening_fee_msat = compute_opening_fee( + parts_sum.msat(), + self.opening_fee_params.min_fee_msat.msat(), + self.opening_fee_params.proportional.ppm() as u64, + ) + .ok_or(Error::FeeOverflow)?; + + if opening_fee_msat >= parts_sum.msat() + || !is_deductible(parts, opening_fee_msat) + { + self.state = SessionState::Failed; + events.push(SessionEvent::PaymentInsufficientForOpeningFee { + opening_fee_msat, + n_parts, + parts_sum, + }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + + // We collected enough parts to fund the channel, transition. + self.state = SessionState::AwaitingChannelReady { + parts: std::mem::take(parts), + opening_fee_msat, + }; + + events.push(SessionEvent::FundingChannel); + + return Ok(ApplyResult { + events, + actions: vec![SessionAction::FundChannel { + peer_id: self.peer_id.clone(), + channel_capacity_msat: self.channel_capacity_msat, + opening_fee_params: self.opening_fee_params.clone(), + }], + }); + } + + // Keep collecting + Ok(ApplyResult { + events, + ..Default::default() + }) + } + (SessionState::Collecting { parts }, SessionInput::CollectTimeout) => { + // Session collection timed out: we fail the session but keep + // the offer active. Next payment can create a new session. + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + self.state = SessionState::Failed; + Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::CollectTimeout { n_parts, parts_sum }, + SessionEvent::SessionFailed, + ], + }) + } + (SessionState::Collecting { parts }, SessionInput::NewBlock { height }) => { + if let Some(min) = cltv_min(parts) { + if height > min { + self.state = SessionState::Failed; + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::UnsafeHtlcTimeout { + height, + cltv_min: min, + }, + SessionEvent::SessionFailed, + ], + }); + } + } + // No parts or height <= cltv_min: stay collecting. + Ok(ApplyResult::default()) + } + ( + SessionState::Collecting { .. }, + ref input @ (SessionInput::ChannelReady { .. } + | SessionInput::PaymentSettled + | SessionInput::PaymentFailed + | SessionInput::FundingBroadcasted + | SessionInput::FundingFailed + | SessionInput::ChannelClosed { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // AwaitChannelReady transitions. + // + (SessionState::AwaitingChannelReady { parts, .. }, SessionInput::AddPart { part }) => { + parts.push(part.clone()); + let n_parts = parts.len(); + let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); + + // We don't check for max parts here as we are in the middle of + // the channel funding. We'll check once we transitioned. + + Ok(ApplyResult { + events: vec![SessionEvent::PaymentPartAdded { + part, + n_parts, + parts_sum, + }], + ..Default::default() + }) + } + ( + SessionState::AwaitingChannelReady { + parts, + opening_fee_msat, + }, + SessionInput::ChannelReady { + channel_id, + funding_psbt, + }, + ) => { + // We are transitioning in any case. + let parts = std::mem::take(parts); + let opening_fee_msat = std::mem::take(opening_fee_msat); + + let n_parts = parts.len(); + + let mut events = vec![SessionEvent::ChannelReady { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }]; + + // Fail if we have too many parts. + if n_parts > self.max_parts { + self.state = SessionState::Abandoned; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionAbandoned); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + ], + events, + }); + } + + // Deduct opening_fee_msat. + let forwards = if let Ok(forwards) = allocate_forwards(&parts, opening_fee_msat) { + forwards + } else { + self.state = SessionState::Abandoned; + events.push(SessionEvent::SessionAbandoned); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + ], + events, + }); + }; + + let parts_sum = + Msat::from_msat(forwards.iter().map(|p| p.forward_msat + p.fee_msat).sum()); + + events.push(SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts, + parts_sum, + opening_fee_msat, + }); + + // Forward HTLCs and await settlement. + self.state = SessionState::AwaitingSettlement { + forwarded_parts: forwards.clone(), + forwarded_amount_msat: forwards.iter().map(|p| p.forward_msat).sum(), + deducted_fee_msat: forwards.iter().map(|p| p.fee_msat).sum(), + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }; + + return Ok(ApplyResult { + actions: vec![SessionAction::ForwardHtlcs { + parts: forwards, + channel_id, + }], + events, + }); + } + ( + SessionState::AwaitingChannelReady { .. }, + ref input @ SessionInput::CollectTimeout, + ) => { + // Collection timeout is only relevant as long as we are still + // collecting parts to cover the fee. Once we opened the channel + // we don't care anymore. + Ok(ApplyResult::unusual_input(&self.state, input)) + } + (SessionState::AwaitingChannelReady { .. }, SessionInput::FundingFailed) => { + self.state = SessionState::Failed; + Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ], + events: vec![SessionEvent::SessionFailed], + }) + } + ( + SessionState::AwaitingChannelReady { parts, .. }, + SessionInput::NewBlock { height }, + ) => { + if let Some(min) = cltv_min(parts) { + if height > min { + self.state = SessionState::Failed; + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::UnsafeHtlcTimeout { + height, + cltv_min: min, + }, + SessionEvent::SessionFailed, + ], + }); + } + } + Ok(ApplyResult::default()) + } + ( + SessionState::AwaitingChannelReady { .. }, + ref input @ (SessionInput::PaymentSettled + | SessionInput::PaymentFailed + | SessionInput::FundingBroadcasted + | SessionInput::ChannelClosed { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // AwaitingSettlement transitions. + // + ( + SessionState::AwaitingSettlement { + forwarded_parts, + forwarded_amount_msat, + deducted_fee_msat, + channel_id, + .. + }, + SessionInput::AddPart { part }, + ) => { + // We forward late-arriving parts immediately in this state. + let fp = ForwardPart { + htlc_id: part.htlc_id, + fee_msat: 0, + forward_msat: part.amount_msat.msat(), + }; + *forwarded_amount_msat += fp.forward_msat; + *deducted_fee_msat += fp.fee_msat; + forwarded_parts.push(fp.clone()); + + let n_parts = forwarded_parts.len(); + let parts_sum = Msat::from_msat(*forwarded_amount_msat + *deducted_fee_msat); + + // We don't check max_parts here as there is not much we can + // do about this at this stage, we definitely need a: + // TODO: Add integration test for #Htlcs > max_accepted_htlcs + + Ok(ApplyResult { + events: vec![ + SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }, + SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts: 1, + parts_sum: part.amount_msat, + opening_fee_msat: 0, + }, + ], + actions: vec![SessionAction::ForwardHtlcs { + parts: vec![fp], + channel_id: channel_id.clone(), + }], + }) + } + ( + SessionState::AwaitingSettlement { + forwarded_parts, + channel_id, + funding_psbt, + .. + }, + SessionInput::PaymentSettled, + ) => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + let parts = std::mem::take(forwarded_parts); + + self.state = SessionState::Broadcasting { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }; + + Ok(ApplyResult { + actions: vec![SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + }], + events: vec![SessionEvent::PaymentSettled { parts }], + }) + } + ( + SessionState::AwaitingSettlement { + channel_id, + funding_psbt, + .. + }, + SessionInput::PaymentFailed, + ) => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + + // Parts are already forwarded so we can't do anything here. + // Abandon session. + + self.state = SessionState::Abandoned; + + Ok(ApplyResult { + actions: vec![ + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + SessionAction::Disconnect, + ], + events: vec![ + SessionEvent::PaymentFailed, + SessionEvent::SessionAbandoned, + ], + }) + } + ( + SessionState::AwaitingSettlement { + channel_id, + funding_psbt, + .. + }, + SessionInput::ChannelClosed { + channel_id: closed_id, + }, + ) if closed_id == *channel_id => { + let channel_id = std::mem::take(channel_id); + let funding_psbt = std::mem::take(funding_psbt); + + self.state = SessionState::Abandoned; + + Ok(ApplyResult { + actions: vec![ + SessionAction::AbandonSession { + channel_id, + funding_psbt, + }, + SessionAction::Disconnect, + ], + events: vec![ + SessionEvent::PaymentFailed, + SessionEvent::SessionAbandoned, + ], + }) + } + ( + SessionState::AwaitingSettlement { .. }, + ref input @ (SessionInput::CollectTimeout + | SessionInput::ChannelReady { .. } + | SessionInput::FundingFailed + | SessionInput::FundingBroadcasted + | SessionInput::ChannelClosed { .. } + | SessionInput::NewBlock { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // Broadcasting transitions. + // + (SessionState::Broadcasting { channel_id, .. }, SessionInput::AddPart { part }) => { + // We already successfully settled htlcs for this payment + // hash, we don't care about max_parts anymore (for whatever + // reason we are collecting more of the same payment hash) + let n_parts = 1; + let parts_sum = part.amount_msat; + + Ok(ApplyResult { + actions: vec![SessionAction::ForwardHtlcs { + parts: vec![ForwardPart { + htlc_id: part.htlc_id, + fee_msat: 0, + forward_msat: part.amount_msat.msat(), + }], + channel_id: channel_id.clone(), + }], + events: vec![ + SessionEvent::PaymentPartAdded { + part: part.clone(), + n_parts, + parts_sum, + }, + SessionEvent::ForwardHtlcs { + channel_id: channel_id.clone(), + n_parts, + parts_sum, + opening_fee_msat: 0, + }, + ], + }) + } + (SessionState::Broadcasting { funding_psbt, .. }, SessionInput::FundingBroadcasted) => { + let funding_psbt = std::mem::take(funding_psbt); + + self.state = SessionState::Succeeded; + Ok(ApplyResult { + actions: vec![], + events: vec![ + SessionEvent::FundingBroadcasted { funding_psbt }, + SessionEvent::SessionSucceeded, + ], + }) + } + ( + SessionState::Broadcasting { .. }, + ref input @ (SessionInput::CollectTimeout + | SessionInput::ChannelReady { .. } + | SessionInput::PaymentSettled + | SessionInput::FundingFailed + | SessionInput::PaymentFailed + | SessionInput::ChannelClosed { .. } + | SessionInput::NewBlock { .. }), + ) => Ok(ApplyResult::unusual_input(&self.state, input)), + + // + // Terminal states. + // + (SessionState::Failed | SessionState::Abandoned | SessionState::Succeeded, input) => { + return Err(Error::InvalidTransition { + state: self.state.clone(), + input, + }) + } + } + } +} + +fn max_deductible(parts: &[PaymentPart]) -> u128 { + parts + .iter() + .map(|p| u128::from(p.amount_msat.msat().saturating_sub(1))) + .sum() +} + +fn is_deductible(parts: &[PaymentPart], opening_fee_msat: u64) -> bool { + max_deductible(parts) >= u128::from(opening_fee_msat) +} + +fn allocate_forwards(parts: &[PaymentPart], opening_fee_msat: u64) -> Result> { + if !is_deductible(parts, opening_fee_msat) { + return Err(Error::InsufficientDeductibleCapacity { + opening_fee_msat, + deductible_capacity_msat: max_deductible(parts), + }); + } + + let mut remaining = opening_fee_msat; + let forwards: Vec = parts + .iter() + .map(|p| { + let amt = p.amount_msat.msat(); + let deduct = remaining.min(amt.saturating_sub(1)); + remaining -= deduct; + ForwardPart { + htlc_id: p.htlc_id, + fee_msat: deduct, + forward_msat: amt - deduct, + } + }) + .collect(); + + debug_assert_eq!(remaining, 0); + Ok(forwards) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::lsps0::Ppm; + use crate::proto::lsps2::Promise; + use chrono::{Duration, Utc}; + + fn part(htlc_id: u64, amount_msat: u64) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry: 100, + } + } + + fn part_with_cltv(htlc_id: u64, amount_msat: u64, cltv_expiry: u32) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry, + } + } + + fn opening_fee_params(min_fee_msat: u64, proportional_ppm: u32) -> OpeningFeeParams { + OpeningFeeParams { + min_fee_msat: Msat::from_msat(min_fee_msat), + proportional: Ppm::from_ppm(proportional_ppm), + valid_until: Utc::now() + Duration::hours(1), + min_lifetime: 144, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat::from_msat(1), + max_payment_size_msat: Msat::from_msat(u64::MAX), + promise: Promise("test-promise".to_owned()), + } + } + + fn session(max_parts: usize, payment_size_msat: Option, min_fee_msat: u64) -> Session { + Session { + state: SessionState::Collecting { parts: vec![] }, + max_parts, + opening_fee_params: opening_fee_params(min_fee_msat, 1_000), + payment_size_msat: payment_size_msat.map(Msat::from_msat), + channel_capacity_msat: Msat::from_msat(100_000_000), + peer_id: "peer-1".to_owned(), + } + } + + #[test] + fn collecting_add_part_emits_payment_part_added() { + let mut s = session(3, Some(2_000), 1); + let p = part(1, 1_000); + + let res = s.apply(SessionInput::AddPart { part: p.clone() }).unwrap(); + + assert!(res.actions.is_empty()); + assert_eq!( + res.events, + vec![SessionEvent::PaymentPartAdded { + part: p, + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }] + ); + } + + #[test] + fn collecting_below_expected_stays_collecting_no_actions() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + } + + #[test] + fn collecting_reaches_expected_transitions_and_funds_channel() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let res = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. })); + assert_eq!(res.actions.len(), 1); + match &res.actions[0] { + SessionAction::FundChannel { + peer_id, + channel_capacity_msat, + opening_fee_params, + } => { + assert_eq!(peer_id, "peer-1"); + assert_eq!(*channel_capacity_msat, Msat::from_msat(100_000_000)); + assert_eq!(opening_fee_params.min_fee_msat, Msat::from_msat(1)); + assert_eq!(opening_fee_params.proportional, Ppm::from_ppm(1_000)); + assert_eq!(opening_fee_params.min_payment_size_msat, Msat::from_msat(1)); + assert_eq!( + opening_fee_params.max_payment_size_msat, + Msat::from_msat(u64::MAX) + ); + assert_eq!( + opening_fee_params.promise, + Promise("test-promise".to_owned()) + ); + } + _ => panic!("expected FundChannel action"), + } + assert!(res.events.contains(&SessionEvent::FundingChannel)); + } + + #[test] + fn collecting_too_many_parts_emits_fail_action() { + let mut s = session(0, Some(1_000), 1); + + let res = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: part(1, 1_000), + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::TooManyParts { n_parts: 1 }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER + }, + SessionAction::FailSession + ] + ); + } + + #[test] + fn collecting_insufficient_for_opening_fee_emits_fail_action() { + let mut s = session(3, Some(1_000), 1_000); + + let res = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: part(1, 1_000), + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::PaymentInsufficientForOpeningFee { + opening_fee_msat: 1_000, + n_parts: 1, + parts_sum: Msat::from_msat(1_000), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn collecting_collect_timeout_with_no_parts_fails_and_transitions_failed() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::CollectTimeout).unwrap(); + + assert!(matches!(s.state, SessionState::Failed)); + assert_eq!( + res.events, + vec![ + SessionEvent::CollectTimeout { + n_parts: 0, + parts_sum: Msat::from_msat(0), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn collecting_collect_timeout_with_parts_reports_count_and_sum() { + let mut s = session(3, Some(5_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 2_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + let res = s.apply(SessionInput::CollectTimeout).unwrap(); + + assert!(matches!(s.state, SessionState::Failed)); + assert_eq!( + res.events, + vec![ + SessionEvent::CollectTimeout { + n_parts: 2, + parts_sum: Msat::from_msat(3_000), + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn failed_rejects_add_part_with_invalid_transition() { + let mut s = session(3, Some(2_000), 1); + s.state = SessionState::Failed; + + let err = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Failed, + input: SessionInput::AddPart { + part: part(1, 1_000), + }, + } + ); + } + + #[test] + fn failed_rejects_collect_timeout_with_invalid_transition() { + let mut s = session(3, Some(2_000), 1); + s.state = SessionState::Failed; + + let err = s.apply(SessionInput::CollectTimeout).unwrap_err(); + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Failed, + input: SessionInput::CollectTimeout, + } + ); + } + + #[test] + fn collecting_payment_size_none_errors_without_mutating_state() { + let mut s = session(3, None, 1); + let err = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap_err(); + + assert_eq!(err, Error::UnimplementedVarAmount); + assert_eq!(s.state, SessionState::Collecting { parts: vec![] }); + } + + #[test] + fn collecting_fee_overflow_returns_fee_overflow() { + let mut s = session(3, Some(u64::MAX), 1); + s.opening_fee_params.proportional = Ppm::from_ppm(u32::MAX); + + let err = s + .apply(SessionInput::AddPart { + part: part(1, u64::MAX), + }) + .unwrap_err(); + assert_eq!(err, Error::FeeOverflow); + } + + #[test] + fn collecting_unexpected_inputs_emit_unusual_input() { + let mut s = session(3, Some(2_000), 1); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + + #[test] + fn channel_ready_forwards_all_parts_and_transitions_to_awaiting_settlement() { + let mut s = session(4, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let p3 = part(3, 500); + + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + assert_eq!( + s.state, + SessionState::AwaitingSettlement { + forwarded_parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + forwarded_amount_msat: 2_498, + deducted_fee_msat: 2, + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + + assert_eq!( + res.actions, + vec![SessionAction::ForwardHtlcs { + parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 3, + parts_sum: Msat::from_msat(2_500), + opening_fee_msat: 2, + }, + ] + ); + } + + #[test] + fn awaiting_settlement_add_part_forwards_single_part() { + let mut s = session(5, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let p3 = part(3, 500); + + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + assert_eq!( + s.state, + SessionState::AwaitingSettlement { + forwarded_parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ForwardPart { + htlc_id: p3.htlc_id, + fee_msat: 0, + forward_msat: 500, + }, + ], + forwarded_amount_msat: 2_498, + deducted_fee_msat: 2, + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + + assert_eq!( + res.actions, + vec![SessionAction::ForwardHtlcs { + parts: vec![p3.clone().into()], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: p3.clone(), + n_parts: 3, + parts_sum: Msat::from_msat(2_500), + }, + SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + opening_fee_msat: 0, + }, + ] + ); + } + + #[test] + fn allocate_forwards_allows_exact_deductible_capacity() { + let parts = vec![part(1, 1_000), part(2, 1_000)]; + + let forwards = allocate_forwards(&parts, 1_998).unwrap(); + + assert_eq!( + forwards, + vec![ + ForwardPart { + htlc_id: 1, + fee_msat: 999, + forward_msat: 1, + }, + ForwardPart { + htlc_id: 2, + fee_msat: 999, + forward_msat: 1, + }, + ] + ); + } + + #[test] + fn payment_settled_transitions_to_broadcasting_and_emits_broadcast_action() { + let mut s = session(4, Some(2_000), 1); + + let p1 = part(1, 1_000); + let p2 = part(2, 1_000); + let _ = s.apply(SessionInput::AddPart { part: p1.clone() }).unwrap(); + let _ = s.apply(SessionInput::AddPart { part: p2.clone() }).unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::PaymentSettled).unwrap(); + + assert_eq!( + s.state, + SessionState::Broadcasting { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + } + ); + assert_eq!( + res.actions, + vec![SessionAction::BroadcastFundingTx { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![SessionEvent::PaymentSettled { + parts: vec![ + ForwardPart { + htlc_id: p1.htlc_id, + fee_msat: 2, + forward_msat: 998, + }, + ForwardPart { + htlc_id: p2.htlc_id, + fee_msat: 0, + forward_msat: 1_000, + }, + ] + }] + ); + } + + #[test] + fn channel_ready_with_too_many_parts_abandons_session_and_fails_htlcs() { + let mut s = session(2, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + // Extra part while awaiting channel ready. + let _ = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap(); + + let res = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }) + .unwrap(); + + assert_eq!(s.state, SessionState::Abandoned); + assert_eq!( + res.events, + vec![ + SessionEvent::ChannelReady { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }, + SessionEvent::TooManyParts { n_parts: 3 }, + SessionEvent::SessionAbandoned, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::AbandonSession { + channel_id: "chan-overflow".to_owned(), + funding_psbt: "psbt-overflow".to_owned(), + }, + ] + ); + } + + #[test] + fn abandoned_rejects_further_inputs_with_invalid_transition() { + let mut s = session(2, Some(2_000), 1); + s.state = SessionState::Abandoned; + + let err = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Abandoned, + input: SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + } + ); + } + + #[test] + fn broadcasting_add_part_forwards_single_htlc() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + let p3 = part(3, 500); + let res = s.apply(SessionInput::AddPart { part: p3.clone() }).unwrap(); + + assert_eq!( + res.actions, + vec![SessionAction::ForwardHtlcs { + parts: vec![p3.clone().into()], + channel_id: "chan-1".to_owned(), + }] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::PaymentPartAdded { + part: p3.clone(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + }, + SessionEvent::ForwardHtlcs { + channel_id: "chan-1".to_owned(), + n_parts: 1, + parts_sum: Msat::from_msat(500), + opening_fee_msat: 0, + }, + ] + ); + } + + #[test] + fn funding_broadcasted_transitions_to_succeeded() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + let res = s.apply(SessionInput::FundingBroadcasted).unwrap(); + + assert_eq!(s.state, SessionState::Succeeded); + assert_eq!(res.actions, vec![]); + assert_eq!( + res.events, + vec![ + SessionEvent::FundingBroadcasted { + funding_psbt: "psbt-1".to_owned(), + }, + SessionEvent::SessionSucceeded, + ] + ); + } + + #[test] + fn succeeded_rejects_new_inputs_with_invalid_transition() { + let mut s = session(4, Some(2_000), 1); + s.state = SessionState::Succeeded; + + let err = s + .apply(SessionInput::AddPart { + part: part(99, 1_000), + }) + .unwrap_err(); + + assert_eq!( + err, + Error::InvalidTransition { + state: SessionState::Succeeded, + input: SessionInput::AddPart { + part: part(99, 1_000), + }, + } + ); + } + + #[test] + fn funding_failed_in_awaiting_channel_ready_fails_htlcs_and_transitions_to_failed() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. })); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ] + ); + assert_eq!(res.events, vec![SessionEvent::SessionFailed]); + } + + #[test] + fn funding_failed_in_awaiting_channel_ready_with_extra_parts_reports_all() { + let mut s = session(5, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + // Extra part arrived while awaiting channel ready. + let _ = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap(); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!(res.events, vec![SessionEvent::SessionFailed]); + } + + #[test] + fn collecting_unexpected_funding_failed_emits_unusual_input() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::FundingFailed).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + + #[test] + fn funding_failed_is_terminal() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s.apply(SessionInput::FundingFailed).unwrap(); + + assert!(s.is_terminal()); + + let err = s + .apply(SessionInput::AddPart { part: part(3, 500) }) + .unwrap_err(); + assert!(matches!(err, Error::InvalidTransition { .. })); + } + + #[test] + fn new_block_collecting_timeout_fails_session() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 51 }).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.events, + vec![ + SessionEvent::UnsafeHtlcTimeout { + height: 51, + cltv_min: 50, + }, + SessionEvent::SessionFailed, + ] + ); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::FailSession, + ] + ); + } + + #[test] + fn new_block_collecting_safe_height_is_noop() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 49 }).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert!(res.events.is_empty()); + } + + #[test] + fn new_block_collecting_no_parts_is_noop() { + let mut s = session(3, Some(2_000), 1); + + let res = s.apply(SessionInput::NewBlock { height: 100 }).unwrap(); + + assert!(matches!(s.state, SessionState::Collecting { .. })); + assert!(res.actions.is_empty()); + assert!(res.events.is_empty()); + } + + #[test] + fn new_block_awaiting_channel_ready_timeout_fails_with_disconnect() { + let mut s = session(3, Some(2_000), 1); + + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(1, 1_000, 50), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part_with_cltv(2, 1_000, 60), + }) + .unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingChannelReady { .. })); + + let res = s.apply(SessionInput::NewBlock { height: 51 }).unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert_eq!( + res.actions, + vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ] + ); + assert_eq!( + res.events, + vec![ + SessionEvent::UnsafeHtlcTimeout { + height: 51, + cltv_min: 50, + }, + SessionEvent::SessionFailed, + ] + ); + } + + #[test] + fn new_block_awaiting_settlement_emits_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::NewBlock { height: 200 }).unwrap(); + + assert!(matches!(s.state, SessionState::AwaitingSettlement { .. })); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + + #[test] + fn awaiting_settlement_payment_failed_disconnects() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + let res = s.apply(SessionInput::PaymentFailed).unwrap(); + + assert_eq!(s.state, SessionState::Abandoned); + assert_eq!( + res.actions, + vec![ + SessionAction::AbandonSession { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }, + SessionAction::Disconnect, + ] + ); + assert_eq!( + res.events, + vec![SessionEvent::PaymentFailed, SessionEvent::SessionAbandoned] + ); + } + + #[test] + fn awaiting_settlement_unusual_inputs_emit_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + + for input in [ + SessionInput::CollectTimeout, + SessionInput::FundingFailed, + SessionInput::FundingBroadcasted, + SessionInput::NewBlock { height: 100 }, + ] { + let res = s.apply(input).unwrap(); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + } + + #[test] + fn broadcasting_unusual_inputs_emit_unusual_input() { + let mut s = session(4, Some(2_000), 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::AddPart { + part: part(2, 1_000), + }) + .unwrap(); + let _ = s + .apply(SessionInput::ChannelReady { + channel_id: "chan-1".to_owned(), + funding_psbt: "psbt-1".to_owned(), + }) + .unwrap(); + let _ = s.apply(SessionInput::PaymentSettled).unwrap(); + + for input in [ + SessionInput::CollectTimeout, + SessionInput::FundingFailed, + SessionInput::PaymentFailed, + SessionInput::NewBlock { height: 100 }, + ] { + let res = s.apply(input).unwrap(); + assert!(res.actions.is_empty()); + assert_eq!(res.events.len(), 1); + assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); + } + } +} diff --git a/plugins/lsps-plugin/src/proto/lsps0.rs b/plugins/lsps-plugin/src/proto/lsps0.rs index 2cb72812931f..96ed4ef068db 100644 --- a/plugins/lsps-plugin/src/proto/lsps0.rs +++ b/plugins/lsps-plugin/src/proto/lsps0.rs @@ -1,6 +1,7 @@ use crate::proto::jsonrpc::{JsonRpcRequest, RpcError}; use core::fmt; use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use std::iter::Sum; use thiserror::Error; const MSAT_PER_SAT: u64 = 1_000; @@ -100,6 +101,12 @@ impl core::fmt::Display for Msat { } } +impl Sum for Msat { + fn sum>(iter: I) -> Self { + Msat(iter.map(|x| x.0).sum()) + } +} + impl Serialize for Msat { fn serialize(&self, serializer: S) -> std::result::Result where diff --git a/plugins/lsps-plugin/src/proto/lsps2.rs b/plugins/lsps-plugin/src/proto/lsps2.rs index 82767a27ea53..1a042545b162 100644 --- a/plugins/lsps-plugin/src/proto/lsps2.rs +++ b/plugins/lsps-plugin/src/proto/lsps2.rs @@ -113,7 +113,7 @@ impl core::fmt::Display for PromiseError { impl core::error::Error for PromiseError {} -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Eq)] #[serde(try_from = "String")] pub struct Promise(pub String); @@ -161,7 +161,7 @@ impl core::fmt::Display for Promise { /// Represents a set of parameters for calculating the opening fee for a JIT /// channel. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(deny_unknown_fields)] // LSPS2 requires the client to fail if a field is unrecognized. pub struct OpeningFeeParams { pub min_fee_msat: Msat, @@ -280,15 +280,14 @@ pub struct Lsps2PolicyGetInfoResponse { } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Lsps2PolicyGetChannelCapacityRequest { +pub struct Lsps2PolicyBuyRequest { pub opening_fee_params: OpeningFeeParams, - pub init_payment_size: Msat, - pub scid: ShortChannelId, + pub payment_size_msat: Option, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct Lsps2PolicyGetChannelCapacityResponse { - pub channel_capacity_msat: Option, +pub struct Lsps2PolicyBuyResponse { + pub channel_capacity_msat: Option, } /// An internal representation of a policy of parameters for calculating the @@ -342,6 +341,33 @@ pub struct DatastoreEntry { pub opening_fee_params: OpeningFeeParams, #[serde(skip_serializing_if = "Option::is_none")] pub expected_payment_size: Option, + pub channel_capacity_msat: Msat, + pub created_at: DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub channel_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub funding_psbt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub funding_txid: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub preimage: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SessionOutcome { + Succeeded, + Abandoned, + Failed, + Timeout, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct FinalizedDatastoreEntry { + #[serde(flatten)] + pub entry: DatastoreEntry, + pub outcome: SessionOutcome, + pub finalized_at: DateTime, } /// Computes the opening fee in millisatoshis as described in LSPS2. From a11d473689ddaa4189b6c1a82d2e78d41e4f85f1 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:07:12 +0100 Subject: [PATCH 02/11] plugins(lsps2): add session actor with action executor boundary Introduce a session actor that runs the FSM in an async task and communicates side effects through an ActionExecutor trait. This separates state machine logic from I/O concerns like RPC calls and datastore writes. --- plugins/lsps-plugin/src/core/lsps2/actor.rs | 495 ++++++++++++++++++++ 1 file changed, 495 insertions(+) create mode 100644 plugins/lsps-plugin/src/core/lsps2/actor.rs diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs new file mode 100644 index 000000000000..3df39c45a5d7 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -0,0 +1,495 @@ +use crate::{ + core::lsps2::{ + provider::DatastoreProvider, + session::{PaymentPart, Session, SessionAction, SessionInput}, + }, + proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::OpeningFeeParams, + }, +}; +use anyhow::Result; +use async_trait::async_trait; +use log::{debug, warn}; +use std::{collections::HashMap, sync::Arc, time::Duration}; +use tokio::{ + sync::{mpsc, oneshot}, + task::JoinHandle, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HtlcResponse { + Forward { + channel_id: String, + fee_msat: u64, + forward_msat: u64, + }, + Fail { + failure_code: &'static str, + }, + Continue, +} + +enum ActorInput { + AddPart { + part: PaymentPart, + reply_tx: oneshot::Sender, + }, + CollectTimeout, + ChannelReady { + channel_id: String, + funding_psbt: String, + }, + FundingFailed, + PaymentSettled { preimage: Option }, + PaymentFailed, + FundingBroadcasted, + NewBlock { + height: u32, + }, + ChannelClosed { + channel_id: String, + }, +} + +/// Adapter for FSM side-effect actions. +#[async_trait] +pub trait ActionExecutor { + async fn fund_channel( + &self, + peer_id: String, + channel_capacity_msat: Msat, + opening_fee_params: OpeningFeeParams, + ) -> Result<(String, String)>; + + async fn abandon_session(&self, channel_id: String, funding_psbt: String) -> Result<()>; + + async fn broadcast_tx(&self, channel_id: String, funding_psbt: String) -> Result; + + async fn disconnect(&self, peer_id: String) -> Result<()>; + + async fn is_channel_alive(&self, channel_id: &str) -> Result; +} + +#[derive(Debug, Clone)] +pub struct ActorInboxHandle { + tx: mpsc::Sender, +} + +impl ActorInboxHandle { + pub async fn add_part(&self, part: PaymentPart) -> Result { + let (reply_tx, rx) = oneshot::channel(); + self.tx.send(ActorInput::AddPart { part, reply_tx }).await?; + Ok(rx.await?) + } + + pub async fn payment_settled(&self, preimage: Option) -> Result<()> { + Ok(self.tx.send(ActorInput::PaymentSettled { preimage }).await?) + } + + pub async fn payment_failed(&self) -> Result<()> { + Ok(self.tx.send(ActorInput::PaymentFailed).await?) + } + + pub async fn new_block(&self, height: u32) -> Result<()> { + Ok(self.tx.send(ActorInput::NewBlock { height }).await?) + } +} + +/// Per-session actor that drives the LSPS2 syncronous session FSM and bridges +/// it to async side effects. +/// +/// It's the runtime boundary around a single `Session`. It owns input ordering, +/// pending HTLC replies, timeout handling, and execution of FMS-emitted side +/// effects and actions. +pub struct SessionActor { + session: Session, + inbox: mpsc::Receiver, + pending_htlcs: HashMap>, + collect_timeout_handle: Option>, + channel_poll_handle: Option>, + self_send: mpsc::Sender, + executor: A, + peer_id: String, + collect_timeout_secs: u64, + scid: ShortChannelId, + datastore: D, +} + +impl + SessionActor +{ + pub fn spawn_session_actor( + session: Session, + executor: A, + peer_id: String, + collect_timeout_secs: u64, + scid: ShortChannelId, + datastore: D, + ) -> ActorInboxHandle { + let (tx, inbox) = mpsc::channel(128); // Should we use max_htlcs? + let actor = SessionActor { + session, + inbox, + pending_htlcs: HashMap::new(), + collect_timeout_handle: None, + channel_poll_handle: None, + self_send: tx.clone(), + executor, + peer_id, + collect_timeout_secs, + scid, + datastore, + }; + tokio::spawn(actor.run()); + ActorInboxHandle { tx } + } + + fn start_collect_timeout(&mut self) { + let tx = self.self_send.clone(); + let timeout = Duration::from_secs(self.collect_timeout_secs); + self.collect_timeout_handle = Some(tokio::spawn(async move { + tokio::time::sleep(timeout).await; + let _ = tx.send(ActorInput::CollectTimeout).await; + })); + } + + fn cancel_collect_timeout(&mut self) { + if let Some(handle) = self.collect_timeout_handle.take() { + handle.abort(); + } + } + + fn start_channel_poll(&mut self, channel_id: String) { + let tx = self.self_send.clone(); + let executor = self.executor.clone(); + self.channel_poll_handle = Some(tokio::spawn(async move { + let interval = Duration::from_secs(5); + loop { + tokio::time::sleep(interval).await; + match executor.is_channel_alive(&channel_id).await { + Ok(true) => continue, + Ok(false) | Err(_) => { + let _ = tx + .send(ActorInput::ChannelClosed { + channel_id: channel_id.clone(), + }) + .await; + break; + } + } + } + })); + } + + fn cancel_channel_poll(&mut self) { + if let Some(handle) = self.channel_poll_handle.take() { + handle.abort(); + } + } + + async fn run(mut self) { + self.start_collect_timeout(); + while let Some(input) = self.inbox.recv().await { + let input = match input { + ActorInput::AddPart { part, reply_tx } => { + let htlc_id = part.htlc_id; + self.pending_htlcs.insert(htlc_id, reply_tx); + SessionInput::AddPart { part } + } + ActorInput::CollectTimeout => SessionInput::CollectTimeout, + ActorInput::ChannelReady { + channel_id, + funding_psbt, + } => SessionInput::ChannelReady { + channel_id, + funding_psbt, + }, + ActorInput::FundingFailed => SessionInput::FundingFailed, + ActorInput::PaymentSettled { preimage } => { + if let Some(ref pre) = preimage { + let datastore = self.datastore.clone(); + let scid = self.scid; + let pre = pre.clone(); + tokio::spawn(async move { + if let Err(e) = datastore.update_session_preimage(&scid, &pre).await { + warn!("update_session_preimage failed for scid={scid}: {e}"); + } + }); + } + SessionInput::PaymentSettled + } + ActorInput::PaymentFailed => SessionInput::PaymentFailed, + ActorInput::FundingBroadcasted => SessionInput::FundingBroadcasted, + ActorInput::NewBlock { height } => SessionInput::NewBlock { height }, + ActorInput::ChannelClosed { channel_id } => { + SessionInput::ChannelClosed { channel_id } + } + }; + + match self.session.apply(input) { + Ok(result) => { + for event in &result.events { + // Note: Add event handler later on. + debug!("session event: {:?}", event); + } + + for action in result.actions { + self.execute_action(action); + } + + if self.session.is_terminal() { + break; + } + } + Err(e) => { + warn!("session FSM error: {e}"); + if self.session.is_terminal() { + self.release_pending_htlcs(); + break; + } + } + } + } + + // We exited the loop, just continue all held HTLCs and let the handler + // decide. + self.release_pending_htlcs(); + + if let Some(outcome) = self.session.outcome() { + if let Err(e) = self.datastore.finalize_session(&self.scid, outcome).await { + warn!("finalize_session failed for scid={}: {e}", self.scid); + } + } + } + + fn execute_action(&mut self, action: SessionAction) { + match action { + SessionAction::FailHtlcs { failure_code } => { + for (_, reply_tx) in self.pending_htlcs.drain() { + let _ = reply_tx.send(HtlcResponse::Fail { failure_code }); + } + } + SessionAction::ForwardHtlcs { parts, channel_id } => { + // First time forwarding HTLCs, we cancel the collect timeout + // and start polling the channel for closure: + self.cancel_collect_timeout(); + self.start_channel_poll(channel_id.clone()); + for part in &parts { + if let Some(reply_tx) = self.pending_htlcs.remove(&part.htlc_id) { + let _ = reply_tx.send(HtlcResponse::Forward { + channel_id: channel_id.clone(), + fee_msat: part.fee_msat, + forward_msat: part.forward_msat, + }); + } + } + } + SessionAction::FundChannel { + peer_id, + channel_capacity_msat, + opening_fee_params, + } => { + let executor = self.executor.clone(); + let self_tx = self.self_send.clone(); + let datastore = self.datastore.clone(); + let scid = self.scid; + tokio::spawn(async move { + match executor + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params) + .await + { + Ok((channel_id, funding_psbt)) => { + if let Err(e) = datastore + .update_session_funding(&scid, &channel_id, &funding_psbt) + .await + { + warn!("update_session_funding failed for scid={scid}: {e}"); + } + let _ = self_tx + .send(ActorInput::ChannelReady { + channel_id, + funding_psbt, + }) + .await; + } + Err(e) => { + warn!("fund_channel failed: {e}"); + let _ = self_tx.send(ActorInput::FundingFailed).await; + } + } + }); + } + SessionAction::FailSession => { + // Is basically a no-op as it is always accompanied with FailHtlcs. + let n = self.release_pending_htlcs(); + debug_assert_eq!(n, 0); + } + SessionAction::AbandonSession { + channel_id, + funding_psbt, + } => { + // Is also basically a no-op as all htlcs should have been + // already forwarded. + let n = self.release_pending_htlcs(); + debug_assert_eq!(n, 0); + + let executor = self.executor.clone(); + tokio::spawn(async move { + if let Err(e) = executor + .abandon_session(channel_id.clone(), funding_psbt.clone()) + .await + { + warn!( + "abandon_session failed (channel_id={}, funding_psbt={}): {}", + channel_id, funding_psbt, e + ); + } + }); + } + SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + } => { + self.cancel_channel_poll(); + let executor = self.executor.clone(); + let self_tx = self.self_send.clone(); + let datastore = self.datastore.clone(); + let scid = self.scid; + tokio::spawn(async move { + match executor + .broadcast_tx(channel_id.clone(), funding_psbt.clone()) + .await + { + Ok(txid) => { + if let Err(e) = datastore + .update_session_funding_txid(&scid, &txid) + .await + { + warn!("update_session_funding_txid failed for scid={scid}: {e}"); + } + let _ = self_tx.send(ActorInput::FundingBroadcasted).await; + } + Err(e) => { + warn!( + "broadcast_tx failed (channel_id={}, funding_psbt={}): {}", + channel_id, funding_psbt, e + ); + } + } + }); + } + SessionAction::Disconnect => { + let executor = self.executor.clone(); + let peer_id = self.peer_id.clone(); + tokio::spawn(async move { + if let Err(e) = executor.disconnect(peer_id.clone()).await { + warn!("disconnect failed (peer_id={}): {}", peer_id, e); + } + }); + } + } + } + + fn release_pending_htlcs(&mut self) -> usize { + let n = self.pending_htlcs.iter().len(); + for (_, reply_tx) in self.pending_htlcs.drain() { + let _ = reply_tx.send(HtlcResponse::Continue); + } + n + } +} + +#[async_trait] +impl ActionExecutor for Arc { + async fn fund_channel( + &self, + peer_id: String, + channel_capacity_msat: Msat, + opening_fee_params: OpeningFeeParams, + ) -> Result<(String, String)> { + (**self) + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params) + .await + } + + async fn abandon_session(&self, channel_id: String, funding_psbt: String) -> Result<()> { + (**self).abandon_session(channel_id, funding_psbt).await + } + + async fn broadcast_tx(&self, channel_id: String, funding_psbt: String) -> Result { + (**self).broadcast_tx(channel_id, funding_psbt).await + } + + async fn disconnect(&self, peer_id: String) -> Result<()> { + (**self).disconnect(peer_id).await + } + + async fn is_channel_alive(&self, channel_id: &str) -> Result { + (**self).is_channel_alive(channel_id).await + } +} + +#[async_trait] +impl DatastoreProvider for Arc { + async fn store_buy_request( + &self, + scid: &ShortChannelId, + peer_id: &bitcoin::secp256k1::PublicKey, + offer: &OpeningFeeParams, + expected_payment_size: &Option, + channel_capacity_msat: &Msat, + ) -> Result { + (**self) + .store_buy_request(scid, peer_id, offer, expected_payment_size, channel_capacity_msat) + .await + } + + async fn get_buy_request( + &self, + scid: &ShortChannelId, + ) -> Result { + (**self).get_buy_request(scid).await + } + + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + (**self).del_buy_request(scid).await + } + + async fn finalize_session( + &self, + scid: &ShortChannelId, + outcome: crate::proto::lsps2::SessionOutcome, + ) -> Result<()> { + (**self).finalize_session(scid, outcome).await + } + + async fn update_session_funding( + &self, + scid: &ShortChannelId, + channel_id: &str, + funding_psbt: &str, + ) -> Result<()> { + (**self) + .update_session_funding(scid, channel_id, funding_psbt) + .await + } + + async fn update_session_funding_txid( + &self, + scid: &ShortChannelId, + funding_txid: &str, + ) -> Result<()> { + (**self) + .update_session_funding_txid(scid, funding_txid) + .await + } + + async fn update_session_preimage( + &self, + scid: &ShortChannelId, + preimage: &str, + ) -> Result<()> { + (**self).update_session_preimage(scid, preimage).await + } +} From 9364d9eaa81fb2ce72dc194eed861ae9c3914fc6 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:07:24 +0100 Subject: [PATCH 03/11] plugins(lsps2): add session manager and unify HTLC handling Add SessionManager that routes incoming HTLCs to the correct session actor by payment hash, replacing the previous handler-based approach. Reworks the policy plugin API and integrates the CLN RPC executor, unifies HTLC handling into the session FSM, and removes the now deprecated handler.rs. --- plugins/lsps-plugin/src/client.rs | 10 +- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 470 +++++- plugins/lsps-plugin/src/core/lsps2/handler.rs | 1367 ----------------- plugins/lsps-plugin/src/core/lsps2/htlc.rs | 802 ---------- plugins/lsps-plugin/src/core/lsps2/manager.rs | 563 +++++++ plugins/lsps-plugin/src/core/lsps2/mod.rs | 4 +- .../lsps-plugin/src/core/lsps2/provider.rs | 37 +- plugins/lsps-plugin/src/core/lsps2/service.rs | 120 +- plugins/lsps-plugin/src/core/lsps2/session.rs | 128 +- plugins/lsps-plugin/src/service.rs | 246 ++- tests/plugins/lsps2_policy.py | 8 +- 11 files changed, 1397 insertions(+), 2358 deletions(-) delete mode 100644 plugins/lsps-plugin/src/core/lsps2/handler.rs delete mode 100644 plugins/lsps-plugin/src/core/lsps2/htlc.rs create mode 100644 plugins/lsps-plugin/src/core/lsps2/manager.rs diff --git a/plugins/lsps-plugin/src/client.rs b/plugins/lsps-plugin/src/client.rs index bdf3475fb0f2..ceb06391398f 100644 --- a/plugins/lsps-plugin/src/client.rs +++ b/plugins/lsps-plugin/src/client.rs @@ -17,11 +17,11 @@ use cln_lsps::{ transport::{MultiplexedTransport, PendingRequests}, }, proto::{ - lsps0::{Msat, LSPS0_MESSAGE_TYPE, LSP_FEATURE_BIT}, + lsps0::{Msat, LSP_FEATURE_BIT}, lsps2::{compute_opening_fee, Lsps2BuyResponse, Lsps2GetInfoResponse, OpeningFeeParams}, }, }; -use cln_plugin::{options, HookBuilder, HookFilter}; +use cln_plugin::options; use cln_rpc::{ model::{ requests::{ @@ -82,10 +82,7 @@ impl ClientState for State { #[tokio::main] async fn main() -> Result<(), anyhow::Error> { if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) - .hook_from_builder( - HookBuilder::new("custommsg", hooks::client_custommsg_hook) - .filters(vec![HookFilter::Int(i64::from(LSPS0_MESSAGE_TYPE))]), - ) + .hook("custommsg", hooks::client_custommsg_hook) .option(OPTION_ENABLED) .rpcmethod( "lsps-listprotocols", @@ -698,6 +695,7 @@ async fn on_openchannel( return Ok(serde_json::json!({ "result": "continue", "mindepth": 0, + "reserve": 0, })); } else { // Not a requested JIT-channel opening, continue. diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 3471d5838d24..8c44c605c4c2 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -1,13 +1,16 @@ use crate::{ - core::lsps2::provider::{ - Blockheight, BlockheightProvider, DatastoreProvider, LightningProvider, Lsps2OfferProvider, + core::lsps2::{ + actor::ActionExecutor, + provider::{ + Blockheight, BlockheightProvider, DatastoreProvider, Lsps2PolicyProvider, + }, }, proto::{ lsps0::Msat, lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, + DatastoreEntry, FinalizedDatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, + Lsps2PolicyGetInfoRequest, Lsps2PolicyGetInfoResponse, OpeningFeeParams, + SessionOutcome, }, }, }; @@ -17,20 +20,29 @@ use bitcoin::secp256k1::PublicKey; use cln_rpc::{ model::{ requests::{ - DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, - GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + AddpsbtoutputRequest, CloseRequest, ConnectRequest, DatastoreMode, DatastoreRequest, + DeldatastoreRequest, DisconnectRequest, FundchannelCancelRequest, + FundchannelCompleteRequest, FundchannelStartRequest, + FundpsbtRequest, GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + SendpsbtRequest, SignpsbtRequest, UnreserveinputsRequest, }, responses::ListdatastoreResponse, }, - primitives::{Amount, AmountOrAll, ChannelState, Sha256, ShortChannelId}, + primitives::{Amount, AmountOrAll, ChannelState, Feerate, Sha256, ShortChannelId}, ClnRpc, }; use core::fmt; +use log::warn; use serde::Serialize; use std::path::PathBuf; +use std::str::FromStr; +use std::time::Duration; pub const DS_MAIN_KEY: &'static str = "lsps"; pub const DS_SUB_KEY: &'static str = "lsps2"; +pub const DS_SESSIONS_KEY: &str = "sessions"; +pub const DS_ACTIVE_KEY: &str = "active"; +pub const DS_FINALIZED_KEY: &str = "finalized"; #[derive(Clone)] pub struct ClnApiRpc { @@ -43,61 +55,281 @@ impl ClnApiRpc { } async fn create_rpc(&self) -> Result { + // Note: Add retry and backoff, be nicer than just failing. ClnRpc::new(&self.rpc_path).await } + + async fn poll_channel_ready( + &self, + channel_id: &Sha256, + timeout: Duration, + interval: Duration, + ) -> Result<()> { + let deadline = tokio::time::Instant::now() + timeout; + loop { + if self.check_channel_normal(channel_id).await? { + return Ok(()); + } + if tokio::time::Instant::now() + interval > deadline { + anyhow::bail!( + "timed out waiting for channel {} to reach CHANNELD_NORMAL", + channel_id + ); + } + tokio::time::sleep(interval).await; + } + } + + async fn check_channel_normal(&self, channel_id: &Sha256) -> Result { + let mut rpc = self.create_rpc().await?; + let r = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: Some(*channel_id), + id: None, + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels")?; + + Ok(r.channels + .first() + .is_some_and(|ch| ch.state == ChannelState::CHANNELD_NORMAL)) + } + + async fn cleanup_failed_funding(&self, peer_id: &PublicKey, psbt: &str) { + if let Err(e) = self.unreserve_inputs(psbt).await { + warn!("cleanup: unreserveinputs for psbt={psbt} failed: {e}"); + } + if let Err(e) = self.cancel_fundchannel(peer_id).await { + warn!("cleanup: fundchannel_cancel failed: {e}"); + } + } + + async fn unreserve_inputs(&self, psbt: &str) -> Result<()> { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&UnreserveinputsRequest { + reserve: None, + psbt: psbt.to_string(), + }) + .await + .with_context(|| "calling unreserveinputs")?; + Ok(()) + } + + async fn cancel_fundchannel(&self, peer_id: &PublicKey) -> Result<()> { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&FundchannelCancelRequest { + id: peer_id.to_owned(), + }) + .await + .with_context(|| "calling fundchannel_cancel")?; + Ok(()) + } + + async fn connect(&self, peer_id: String) -> Result<()> { + // Note: We could add a retry here. + let mut rpc = self.create_rpc().await?; + let _ = rpc + .call_typed(&ConnectRequest { + host: None, + port: None, + id: peer_id, + }) + .await + .with_context(|| "calling connect")?; + Ok(()) + } +} + +/// Converts msat to sat, rounding up to avoid underfunding. +fn msat_to_sat_ceil(msat: u64) -> u64 { + msat.div_ceil(1000) } #[async_trait] -impl LightningProvider for ClnApiRpc { - async fn fund_jit_channel( +impl ActionExecutor for ClnApiRpc { + async fn fund_channel( &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> Result<(Sha256, String)> { + peer_id: String, + channel_size: Msat, + _opening_fee_params: OpeningFeeParams, + ) -> anyhow::Result<(String, String)> { + let pk = PublicKey::from_str(&peer_id) + .with_context(|| format!("parsing peer_id '{peer_id}'"))?; + let channel_sat = msat_to_sat_ceil(channel_size.msat()); + + self.connect(peer_id).await?; + let mut rpc = self.create_rpc().await?; - let res = rpc - .call_typed(&FundchannelRequest { + let start_res = rpc + .call_typed(&FundchannelStartRequest { + id: pk, + amount: Amount::from_sat(channel_sat), + mindepth: Some(0), + channel_type: Some(vec![12, 46, 50]), // zero_conf channel announce: Some(false), close_to: None, - compact_lease: None, feerate: None, - minconf: None, - mindepth: Some(0), push_msat: None, - request_amt: None, + reserve: Some(Amount::from_sat(0)), + }) + .await + .with_context(|| "calling fundchannel_start")?; + let funding_address = start_res.funding_address; + + // Reserve input and add to tx + let mut rpc = self.create_rpc().await?; + let fundpsbt_res = match rpc + .call_typed(&FundpsbtRequest { + satoshi: AmountOrAll::Amount(Amount::from_sat(channel_sat)), + feerate: Feerate::Normal, + startweight: 1000, + excess_as_change: Some(true), + locktime: None, + min_witness_weight: None, + minconf: None, + nonwrapped: None, + opening_anchor_channel: None, reserve: None, - channel_type: Some(vec![12, 46, 50]), - utxos: None, - amount: AmountOrAll::Amount(Amount::from_msat(amount.msat())), - id: peer_id.to_owned(), }) .await - .with_context(|| "calling fundchannel")?; - Ok((res.channel_id, res.txid)) + { + Ok(r) => r, + Err(e) => { + self.cancel_fundchannel(&pk).await.ok(); + return Err(anyhow::Error::new(e).context("calling fundpsbt")); + } + }; + + let addout_res = match rpc + .call_typed(&AddpsbtoutputRequest { + satoshi: Amount::from_sat(channel_sat), + initialpsbt: Some(fundpsbt_res.psbt.clone()), + destination: Some(funding_address), + locktime: None, + }) + .await + { + Ok(r) => r, + Err(e) => { + self.cleanup_failed_funding(&pk, &fundpsbt_res.psbt).await; + return Err(anyhow::Error::new(e).context("calling addpsbtoutput")); + } + }; + let psbt = addout_res.psbt; + + let complete_res = match rpc + .call_typed(&FundchannelCompleteRequest { + id: pk, + psbt: psbt.clone(), + withhold: Some(true), + }) + .await + { + Ok(r) => r, + Err(e) => { + self.cleanup_failed_funding(&pk, &psbt).await; + return Err(anyhow::Error::new(e).context("calling fundchannel_complete")); + } + }; + let channel_id = complete_res.channel_id; + + if let Err(e) = self + .poll_channel_ready( + &channel_id, + Duration::from_secs(120), + Duration::from_secs(1), + ) + .await + { + self.cleanup_failed_funding(&pk, &psbt).await; + return Err(e); + } + + Ok((channel_id.to_string(), psbt)) } - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> Result { + async fn broadcast_tx( + &self, + _channel_id: String, + funding_psbt: String, + ) -> anyhow::Result { let mut rpc = self.create_rpc().await?; - let r = rpc - .call_typed(&ListpeerchannelsRequest { - channel_id: None, - id: Some(peer_id.to_owned()), - short_channel_id: None, + let sign_res = rpc + .call_typed(&SignpsbtRequest { + psbt: funding_psbt, + signonly: None, }) .await - .with_context(|| "calling listpeerchannels")?; + .with_context(|| "calling signpsbt")?; + let send_res = rpc + .call_typed(&SendpsbtRequest { + psbt: sign_res.signed_psbt, + reserve: None, + }) + .await + .with_context(|| "calling sendpsbt")?; + Ok(send_res.txid) + } - let chs = r - .channels - .iter() - .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); - if let Some(ch) = chs { - if ch.state == ChannelState::CHANNELD_NORMAL { - return Ok(true); - } + async fn abandon_session( + &self, + channel_id: String, + funding_psbt: String, + ) -> anyhow::Result<()> { + let close_res = { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&CloseRequest { + destination: None, + fee_negotiation_step: None, + force_lease_closed: None, + unilateraltimeout: Some(1), // We didn't even broadcast the channel yet. + wrong_funding: None, + feerange: None, + id: channel_id.clone(), + }) + .await + .with_context(|| format!("calling close for channel_id={channel_id}")) + }; + + if let Err(e) = &close_res { + warn!("abandon_session: close failed for channel_id={channel_id}: {e}"); + } + + let unreserve_res = self.unreserve_inputs(&funding_psbt).await; + if let Err(e) = &unreserve_res { + warn!("abandon_session: unreserveinputs failed for funding_psbt={funding_psbt}: {e}"); } - return Ok(false); + match (close_res, unreserve_res) { + (Ok(_), Ok(())) => Ok(()), + (Err(close_err), Ok(())) => Err(close_err), + (Ok(_), Err(unreserve_err)) => Err(unreserve_err), + (Err(close_err), Err(unreserve_err)) => Err(anyhow::anyhow!( + "abandon_session failed for channel_id={channel_id}: close failed: {close_err}; unreserveinputs failed for funding_psbt={funding_psbt}: {unreserve_err}" + )), + } + } + + async fn disconnect(&self, peer_id: String) -> anyhow::Result<()> { + let pk = PublicKey::from_str(&peer_id) + .with_context(|| format!("parsing peer_id '{peer_id}'"))?; + let mut rpc = self.create_rpc().await?; + let _ = rpc + .call_typed(&DisconnectRequest { + id: pk, + force: None, + }) + .await + .with_context(|| "calling disconnect")?; + Ok(()) + } + + async fn is_channel_alive(&self, channel_id: &str) -> anyhow::Result { + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + self.check_channel_normal(&sha).await } } @@ -109,6 +341,7 @@ impl DatastoreProvider for ClnApiRpc { peer_id: &PublicKey, opening_fee_params: &OpeningFeeParams, expected_payment_size: &Option, + channel_capacity_msat: &Msat, ) -> Result { let mut rpc = self.create_rpc().await?; #[derive(Serialize)] @@ -117,12 +350,28 @@ impl DatastoreProvider for ClnApiRpc { opening_fee_params: &'a OpeningFeeParams, #[serde(borrow)] expected_payment_size: &'a Option, + channel_capacity_msat: &'a Msat, + created_at: chrono::DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + channel_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + funding_psbt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + funding_txid: Option, + #[serde(skip_serializing_if = "Option::is_none")] + preimage: Option, } let ds = BorrowedDatastoreEntry { peer_id, opening_fee_params, expected_payment_size, + channel_capacity_msat, + created_at: chrono::Utc::now(), + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, }; let json_str = serde_json::to_string(&ds)?; @@ -134,6 +383,8 @@ impl DatastoreProvider for ClnApiRpc { key: vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), scid.to_string(), ], }; @@ -152,6 +403,8 @@ impl DatastoreProvider for ClnApiRpc { let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), scid.to_string(), ]; let res = rpc @@ -170,6 +423,8 @@ impl DatastoreProvider for ClnApiRpc { let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), scid.to_string(), ]; @@ -182,11 +437,131 @@ impl DatastoreProvider for ClnApiRpc { Ok(()) } + + async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()> { + let entry = match self.get_buy_request(scid).await { + Ok(e) => e, + Err(e) => { + warn!("finalize_session: active entry for scid={scid} already gone: {e}"); + return Ok(()); + } + }; + + let finalized = FinalizedDatastoreEntry { + entry, + outcome, + finalized_at: chrono::Utc::now(), + }; + let json_str = serde_json::to_string(&finalized)?; + + let mut rpc = self.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_FINALIZED_KEY.to_string(), + scid.to_string(), + ]; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::MUST_CREATE), + string: Some(json_str), + key, + }) + .await + .with_context(|| "calling datastore for finalize_session")?; + + self.del_buy_request(scid).await?; + Ok(()) + } + + async fn update_session_funding( + &self, + scid: &ShortChannelId, + channel_id: &str, + funding_psbt: &str, + ) -> Result<()> { + let mut entry = self.get_buy_request(scid).await?; + entry.channel_id = Some(channel_id.to_string()); + entry.funding_psbt = Some(funding_psbt.to_string()); + let json_str = serde_json::to_string(&entry)?; + + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for update_session_funding")?; + Ok(()) + } + + async fn update_session_funding_txid( + &self, + scid: &ShortChannelId, + funding_txid: &str, + ) -> Result<()> { + let mut entry = self.get_buy_request(scid).await?; + entry.funding_txid = Some(funding_txid.to_string()); + let json_str = serde_json::to_string(&entry)?; + + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for update_session_funding_txid")?; + Ok(()) + } + + async fn update_session_preimage(&self, scid: &ShortChannelId, preimage: &str) -> Result<()> { + let mut entry = self.get_buy_request(scid).await?; + entry.preimage = Some(preimage.to_string()); + let json_str = serde_json::to_string(&entry)?; + + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for update_session_preimage")?; + Ok(()) + } } #[async_trait] -impl Lsps2OfferProvider for ClnApiRpc { - async fn get_offer( +impl Lsps2PolicyProvider for ClnApiRpc { + async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, ) -> Result { @@ -196,15 +571,12 @@ impl Lsps2OfferProvider for ClnApiRpc { .context("failed to call lsps2-policy-getpolicy") } - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> Result { + async fn buy(&self, request: &Lsps2PolicyBuyRequest) -> Result { let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getchannelcapacity", params) + rpc.call_raw("lsps2-policy-buy", request) .await .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getchannelcapacity") + .with_context(|| "calling lsps2-policy-buy") } } diff --git a/plugins/lsps-plugin/src/core/lsps2/handler.rs b/plugins/lsps-plugin/src/core/lsps2/handler.rs deleted file mode 100644 index 88124788a62f..000000000000 --- a/plugins/lsps-plugin/src/core/lsps2/handler.rs +++ /dev/null @@ -1,1367 +0,0 @@ -use crate::{ - core::lsps2::service::Lsps2Handler, - lsps2::{ - cln::{HtlcAcceptedRequest, HtlcAcceptedResponse, TLV_FORWARD_AMT}, - DS_MAIN_KEY, DS_SUB_KEY, - }, - proto::{ - jsonrpc::{RpcError, RpcErrorExt as _}, - lsps0::{LSPS0RpcErrorExt, Msat, ShortChannelId}, - lsps2::{ - compute_opening_fee, - failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, - DatastoreEntry, Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, - Lsps2GetInfoResponse, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, PolicyOpeningFeeParams, Promise, - }, - }, -}; -use anyhow::{Context, Result as AnyResult}; -use async_trait::async_trait; -use bitcoin::{ - hashes::{sha256::Hash as Sha256, Hash as _}, - secp256k1::PublicKey, -}; -use chrono::Utc; -use cln_rpc::{ - model::{ - requests::{ - DatastoreMode, DatastoreRequest, DeldatastoreRequest, FundchannelRequest, - GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, - }, - responses::ListdatastoreResponse, - }, - primitives::{Amount, AmountOrAll, ChannelState}, - ClnRpc, -}; -use log::{debug, warn}; -use rand::{rng, Rng as _}; -use serde::Serialize; -use std::{fmt, path::PathBuf, sync::Arc, time::Duration}; - -const DEFAULT_CLTV_EXPIRY_DELTA: u32 = 144; - -#[derive(Clone)] -pub struct ClnApiRpc { - rpc_path: PathBuf, -} - -impl ClnApiRpc { - pub fn new(rpc_path: PathBuf) -> Self { - Self { rpc_path } - } - - async fn create_rpc(&self) -> AnyResult { - ClnRpc::new(&self.rpc_path).await - } -} - -#[async_trait] -impl LightningProvider for ClnApiRpc { - async fn fund_jit_channel( - &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - let mut rpc = self.create_rpc().await?; - let res = rpc - .call_typed(&FundchannelRequest { - announce: Some(false), - close_to: None, - compact_lease: None, - feerate: None, - minconf: None, - mindepth: Some(0), - push_msat: None, - request_amt: None, - reserve: None, - channel_type: Some(vec![12, 46, 50]), - utxos: None, - amount: AmountOrAll::Amount(Amount::from_msat(amount.msat())), - id: peer_id.to_owned(), - }) - .await - .with_context(|| "calling fundchannel")?; - Ok((res.channel_id, res.txid)) - } - - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let r = rpc - .call_typed(&ListpeerchannelsRequest { - id: Some(peer_id.to_owned()), - short_channel_id: None, - }) - .await - .with_context(|| "calling listpeerchannels")?; - - let chs = r - .channels - .iter() - .find(|&ch| ch.channel_id.is_some_and(|id| id == *channel_id)); - if let Some(ch) = chs { - if ch.state == ChannelState::CHANNELD_NORMAL { - return Ok(true); - } - } - - return Ok(false); - } -} - -#[async_trait] -impl DatastoreProvider for ClnApiRpc { - async fn store_buy_request( - &self, - scid: &ShortChannelId, - peer_id: &PublicKey, - opening_fee_params: &OpeningFeeParams, - expected_payment_size: &Option, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - #[derive(Serialize)] - struct BorrowedDatastoreEntry<'a> { - peer_id: &'a PublicKey, - opening_fee_params: &'a OpeningFeeParams, - #[serde(borrow)] - expected_payment_size: &'a Option, - } - - let ds = BorrowedDatastoreEntry { - peer_id, - opening_fee_params, - expected_payment_size, - }; - let json_str = serde_json::to_string(&ds)?; - - let ds = DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::MUST_CREATE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ], - }; - - let _ = rpc - .call_typed(&ds) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling datastore")?; - - Ok(true) - } - - async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ]; - let res = rpc - .call_typed(&ListdatastoreRequest { - key: Some(key.clone()), - }) - .await - .with_context(|| "calling listdatastore")?; - - let (rec, _) = deserialize_by_key(&res, key)?; - Ok(rec) - } - - async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()> { - let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - scid.to_string(), - ]; - - let _ = rpc - .call_typed(&DeldatastoreRequest { - generation: None, - key, - }) - .await; - - Ok(()) - } -} - -#[async_trait] -impl Lsps2OfferProvider for ClnApiRpc { - async fn get_offer( - &self, - request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getpolicy", request) - .await - .context("failed to call lsps2-policy-getpolicy") - } - - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - let mut rpc = self.create_rpc().await?; - rpc.call_raw("lsps2-policy-getchannelcapacity", params) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling lsps2-policy-getchannelcapacity") - } -} - -#[async_trait] -impl BlockheightProvider for ClnApiRpc { - async fn get_blockheight(&self) -> AnyResult { - let mut rpc = self.create_rpc().await?; - let info = rpc - .call_typed(&GetinfoRequest {}) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling getinfo")?; - Ok(info.blockheight) - } -} - -#[async_trait] -pub trait Lsps2OfferProvider: Send + Sync { - async fn get_offer( - &self, - request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult; - - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult; -} - -type Blockheight = u32; - -#[async_trait] -pub trait BlockheightProvider: Send + Sync { - async fn get_blockheight(&self) -> AnyResult; -} - -#[async_trait] -pub trait DatastoreProvider: Send + Sync { - async fn store_buy_request( - &self, - scid: &ShortChannelId, - peer_id: &PublicKey, - offer: &OpeningFeeParams, - expected_payment_size: &Option, - ) -> AnyResult; - - async fn get_buy_request(&self, scid: &ShortChannelId) -> AnyResult; - async fn del_buy_request(&self, scid: &ShortChannelId) -> AnyResult<()>; -} - -#[async_trait] -pub trait LightningProvider: Send + Sync { - async fn fund_jit_channel( - &self, - peer_id: &PublicKey, - amount: &Msat, - ) -> AnyResult<(Sha256, String)>; - - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Sha256) -> AnyResult; -} - -pub struct Lsps2ServiceHandler { - pub api: Arc, - pub promise_secret: [u8; 32], -} - -impl Lsps2ServiceHandler { - pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { - Lsps2ServiceHandler { - api, - promise_secret: promise_seret.to_owned(), - } - } -} - -async fn get_info_handler( - api: Arc, - secret: &[u8; 32], - request: &Lsps2GetInfoRequest, -) -> std::result::Result { - let res_data = api - .get_offer(&Lsps2PolicyGetInfoRequest { - token: request.token.clone(), - }) - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - if res_data.client_rejected { - return Err(RpcError::client_rejected("client was rejected")); - }; - - let opening_fee_params_menu = res_data - .policy_opening_fee_params_menu - .iter() - .map(|v| make_opening_fee_params(v, secret)) - .collect::, _>>()?; - - Ok(Lsps2GetInfoResponse { - opening_fee_params_menu, - }) -} - -fn make_opening_fee_params( - v: &PolicyOpeningFeeParams, - secret: &[u8; 32], -) -> Result { - let promise: Promise = v - .get_hmac_hex(secret) - .try_into() - .map_err(|_| RpcError::internal_error("internal error"))?; - Ok(OpeningFeeParams { - min_fee_msat: v.min_fee_msat, - proportional: v.proportional, - valid_until: v.valid_until, - min_lifetime: v.min_lifetime, - max_client_to_self_delay: v.max_client_to_self_delay, - min_payment_size_msat: v.min_payment_size_msat, - max_payment_size_msat: v.max_payment_size_msat, - promise, - }) -} - -#[async_trait] -impl Lsps2Handler - for Lsps2ServiceHandler -{ - async fn handle_get_info( - &self, - request: Lsps2GetInfoRequest, - ) -> std::result::Result { - get_info_handler(self.api.clone(), &self.promise_secret, &request).await - } - - async fn handle_buy( - &self, - peer_id: PublicKey, - request: Lsps2BuyRequest, - ) -> core::result::Result { - let fee_params = request.opening_fee_params; - - // FIXME: In the future we should replace the \`None\` with a meaningful - // value that reflects the inbound capacity for this node from the - // public network for a better pre-condition check on the payment_size. - fee_params.validate(&self.promise_secret, request.payment_size_msat, None)?; - - // Generate a tmp scid to identify jit channel request in htlc. - let blockheight = self - .api - .get_blockheight() - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - // FIXME: Future task: Check that we don't conflict with any jit scid we - // already handed out -> Check datastore entries. - let jit_scid = ShortChannelId::from(generate_jit_scid(blockheight)); - - let ok = self - .api - .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) - .await - .map_err(|_| RpcError::internal_error("internal error"))?; - - if !ok { - return Err(RpcError::internal_error("internal error"))?; - } - - Ok(Lsps2BuyResponse { - jit_channel_scid: jit_scid, - // We can make this configurable if necessary. - lsp_cltv_expiry_delta: DEFAULT_CLTV_EXPIRY_DELTA, - // We can implement the other mode later on as we might have to do - // some additional work on core-lightning to enable this. - client_trusts_lsp: false, - }) - } -} - -fn generate_jit_scid(best_blockheigt: u32) -> u64 { - let mut rng = rng(); - let block = best_blockheigt + 6; // Approx 1 hour in the future and should avoid collision with confirmed channels - let tx_idx: u32 = rng.random_range(0..5000); - let output_idx: u16 = rng.random_range(0..10); - - ((block as u64) << 40) | ((tx_idx as u64) << 16) | (output_idx as u64) -} - -pub struct HtlcAcceptedHookHandler { - api: A, - htlc_minimum_msat: u64, - backoff_listpeerchannels: Duration, -} - -impl HtlcAcceptedHookHandler { - pub fn new(api: A, htlc_minimum_msat: u64) -> Self { - Self { - api, - htlc_minimum_msat, - backoff_listpeerchannels: Duration::from_secs(10), - } - } -} - -impl HtlcAcceptedHookHandler { - pub async fn handle(&self, req: HtlcAcceptedRequest) -> AnyResult { - let scid = match req.onion.short_channel_id { - Some(scid) => scid, - None => { - // We are the final destination of this htlc. - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - } - }; - - // A) Is this SCID one that we care about? - let ds_rec = match self.api.get_buy_request(&scid).await { - Ok(rec) => rec, - Err(_) => { - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - } - }; - - // Fixme: Check that we don't have a channel yet with the peer that we await to - // become READY to use. - // --- - - // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on - // Fixme: We continue mpp for now to let the test mock handle the htlc, as we need - // to test the client implementation for mpp payments. - if ds_rec.expected_payment_size.is_some() { - warn!("mpp payments are not implemented yet"); - return Ok(HtlcAcceptedResponse::continue_(None, None, None)); - // return Ok(HtlcAcceptedResponse::fail( - // Some(UNKNOWN_NEXT_PEER.to_string()), - // None, - // )); - } - - // B) Is the fee option menu still valid? - let now = Utc::now(); - if now >= ds_rec.opening_fee_params.valid_until { - // Not valid anymore, remove from DS and fail HTLC. - let _ = self.api.del_buy_request(&scid).await; - return Ok(HtlcAcceptedResponse::fail( - Some(TEMPORARY_CHANNEL_FAILURE.to_string()), - None, - )); - } - - // C) Is the amount in the boundaries of the fee menu? - if req.htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() - || req.htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() - { - // No! reject the HTLC. - debug!("amount_msat for scid: {}, was too low or to high", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - - // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) - let opening_fee = if let Some(opening_fee) = compute_opening_fee( - req.htlc.amount_msat.msat(), - ds_rec.opening_fee_params.min_fee_msat.msat(), - ds_rec.opening_fee_params.proportional.ppm() as u64, - ) { - if opening_fee + self.htlc_minimum_msat >= req.htlc.amount_msat.msat() { - debug!("amount_msat for scid: {}, does not cover opening fee", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - opening_fee - } else { - // The computation overflowed. - debug!("amount_msat for scid: {}, was too low or to high", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - }; - - // E) We made it, open a channel to the peer. - let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { - opening_fee_params: ds_rec.opening_fee_params, - init_payment_size: Msat::from_msat(req.htlc.amount_msat.msat()), - scid, - }; - let ch_cap_res = match self.api.get_channel_capacity(&ch_cap_req).await { - Ok(r) => r, - Err(e) => { - warn!("failed to get channel capacity for scid {}: {}", scid, e); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - let cap = match ch_cap_res.channel_capacity_msat { - Some(c) => Msat::from_msat(c), - None => { - debug!("policy giver does not allow channel for scid {}", scid); - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - // We take the policy-giver seriously, if the capacity is too low, we - // still try to open the channel. - // Fixme: We may check that the capacity is ge than the - // (amount_msat - opening fee) in the future. - // Fixme: Make this configurable, maybe return the whole request from - // the policy giver? - let channel_id = match self.api.fund_jit_channel(&ds_rec.peer_id, &cap).await { - Ok((channel_id, _)) => channel_id, - Err(_) => { - return Ok(HtlcAcceptedResponse::fail( - Some(UNKNOWN_NEXT_PEER.to_string()), - None, - )); - } - }; - - // F) Wait for the peer to send `channel_ready`. - // Fixme: Use event to check for channel ready, - // Fixme: Check for htlc timeout if peer refuses to send "ready". - // Fixme: handle unexpected channel states. - loop { - match self - .api - .is_channel_ready(&ds_rec.peer_id, &channel_id) - .await - { - Ok(true) => break, - Ok(false) | Err(_) => tokio::time::sleep(self.backoff_listpeerchannels).await, - }; - } - - // G) We got a working channel, deduct fee and forward htlc. - let deducted_amt_msat = req.htlc.amount_msat.msat() - opening_fee; - let mut payload = req.onion.payload.clone(); - payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); - - // It is okay to unwrap the next line as we do not have duplicate entries. - let payload_bytes = payload.to_bytes().unwrap(); - debug!("about to send payload: {:02x?}", &payload_bytes); - - let mut extra_tlvs = req.htlc.extra_tlvs.unwrap_or_default().clone(); - extra_tlvs.set_u64(65537, opening_fee); - let extra_tlvs_bytes = extra_tlvs.to_bytes().unwrap(); - debug!("extra_tlv: {:02x?}", extra_tlvs_bytes); - - Ok(HtlcAcceptedResponse::continue_( - Some(payload_bytes), - Some(channel_id.as_byte_array().to_vec()), - Some(extra_tlvs_bytes), - )) - } -} - -#[derive(Debug)] -pub enum DsError { - /// No datastore entry with this exact key. - NotFound { key: Vec }, - /// Entry existed but had neither `string` nor `hex`. - MissingValue { key: Vec }, - /// JSON parse failed (from `string` or decoded `hex`). - JsonParse { - key: Vec, - source: serde_json::Error, - }, - /// Hex decode failed. - HexDecode { - key: Vec, - source: hex::FromHexError, - }, -} - -impl fmt::Display for DsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - DsError::NotFound { key } => write!(f, "no datastore entry for key {:?}", key), - DsError::MissingValue { key } => write!( - f, - "datastore entry had neither `string` nor `hex` for key {:?}", - key - ), - DsError::JsonParse { key, source } => { - write!(f, "failed to parse JSON at key {:?}: {}", key, source) - } - DsError::HexDecode { key, source } => { - write!(f, "failed to decode hex at key {:?}: {}", key, source) - } - } - } -} - -impl std::error::Error for DsError {} - -pub fn deserialize_by_key( - resp: &ListdatastoreResponse, - key: K, -) -> std::result::Result<(DatastoreEntry, Option), DsError> -where - K: AsRef<[String]>, -{ - let wanted: &[String] = key.as_ref(); - - let ds = resp - .datastore - .iter() - .find(|d| d.key.as_slice() == wanted) - .ok_or_else(|| DsError::NotFound { - key: wanted.to_vec(), - })?; - - // Prefer `string`, fall back to `hex` - if let Some(s) = &ds.string { - let value = serde_json::from_str::(s).map_err(|e| DsError::JsonParse { - key: ds.key.clone(), - source: e, - })?; - return Ok((value, ds.generation)); - } - - if let Some(hx) = &ds.hex { - let bytes = hex::decode(hx).map_err(|e| DsError::HexDecode { - key: ds.key.clone(), - source: e, - })?; - let value = - serde_json::from_slice::(&bytes).map_err(|e| DsError::JsonParse { - key: ds.key.clone(), - source: e, - })?; - return Ok((value, ds.generation)); - } - - Err(DsError::MissingValue { - key: ds.key.clone(), - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - lsps2::cln::{tlv::TlvStream, HtlcAcceptedResult}, - proto::{ - jsonrpc, - lsps0::Ppm, - lsps2::{Lsps2PolicyGetInfoResponse, PolicyOpeningFeeParams}, - }, - }; - use anyhow::bail; - use chrono::{TimeZone, Utc}; - use cln_rpc::primitives::{Amount, PublicKey}; - use cln_rpc::RpcError as ClnRpcError; - use std::sync::{Arc, Mutex}; - - const PUBKEY: [u8; 33] = [ - 0x02, 0x79, 0xbe, 0x66, 0x7e, 0xf9, 0xdc, 0xbb, 0xac, 0x55, 0xa0, 0x62, 0x95, 0xce, 0x87, - 0x0b, 0x07, 0x02, 0x9b, 0xfc, 0xdb, 0x2d, 0xce, 0x28, 0xd9, 0x59, 0xf2, 0x81, 0x5b, 0x16, - 0xf8, 0x17, 0x98, - ]; - - fn create_peer_id() -> PublicKey { - PublicKey::from_slice(&PUBKEY).expect("Valid pubkey") - } - - /// Build a pair: policy params + buy params with a Promise derived from `secret` - fn params_with_promise(secret: &[u8; 32]) -> (PolicyOpeningFeeParams, OpeningFeeParams) { - let policy = PolicyOpeningFeeParams { - min_fee_msat: Msat(2_000), - proportional: Ppm(10_000), - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1_000_000), - max_payment_size_msat: Msat(100_000_000), - }; - let hex = policy.get_hmac_hex(secret); - let promise: Promise = hex.try_into().expect("hex->Promise"); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - (policy, buy) - } - - #[derive(Clone, Default)] - struct FakeCln { - lsps2_getpolicy_response: Arc>>, - lsps2_getpolicy_error: Arc>>, - blockheight_response: Option, - blockheight_error: Arc>>, - store_buy_request_response: bool, - get_buy_request_response: Arc>>, - get_buy_request_error: Arc>>, - fund_channel_error: Arc>>, - fund_channel_response: Arc>>, - lsps2_getchannelcapacity_response: - Arc>>, - lsps2_getchannelcapacity_error: Arc>>, - } - - #[async_trait] - impl Lsps2OfferProvider for FakeCln { - async fn get_offer( - &self, - _request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - if let Some(err) = self.lsps2_getpolicy_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - }; - if let Some(res) = self.lsps2_getpolicy_response.lock().unwrap().take() { - return Ok(Lsps2PolicyGetInfoResponse { - policy_opening_fee_params_menu: res.policy_opening_fee_params_menu, - client_rejected: false, - }); - }; - panic!("No lsps2 response defined"); - } - - async fn get_channel_capacity( - &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - if let Some(err) = self.lsps2_getchannelcapacity_error.lock().unwrap().take() { - return Err(anyhow::Error::new(err).context("from fake api")); - } - if let Some(res) = self - .lsps2_getchannelcapacity_response - .lock() - .unwrap() - .take() - { - return Ok(res); - } - panic!("No lsps2 getchannelcapacity response defined"); - } - } - - #[async_trait] - impl BlockheightProvider for FakeCln { - async fn get_blockheight(&self) -> AnyResult { - if let Some(err) = self.blockheight_error.lock().unwrap().take() { - return Err(err); - }; - if let Some(blockheight) = self.blockheight_response { - return Ok(blockheight); - }; - panic!("No cln getinfo response defined"); - } - } - - #[async_trait] - impl DatastoreProvider for FakeCln { - async fn store_buy_request( - &self, - _scid: &ShortChannelId, - _peer_id: &PublicKey, - _offer: &OpeningFeeParams, - _payment_size_msat: &Option, - ) -> AnyResult { - Ok(self.store_buy_request_response) - } - - async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { - if let Some(err) = self.get_buy_request_error.lock().unwrap().take() { - return Err(err); - } - if let Some(res) = self.get_buy_request_response.lock().unwrap().take() { - return Ok(res); - } else { - bail!("request not found") - } - } - - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { - Ok(()) - } - } - - #[async_trait] - impl LightningProvider for FakeCln { - async fn fund_jit_channel( - &self, - _peer_id: &PublicKey, - _amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - if let Some(err) = self.fund_channel_error.lock().unwrap().take() { - return Err(err); - } - if let Some(res) = self.fund_channel_response.lock().unwrap().take() { - return Ok(res); - } else { - bail!("request not found") - } - } - - async fn is_channel_ready( - &self, - _peer_id: &PublicKey, - _channel_id: &Sha256, - ) -> AnyResult { - Ok(true) - } - } - - fn create_test_htlc_request( - scid: Option, - amount_msat: u64, - ) -> HtlcAcceptedRequest { - let payload = TlvStream::default(); - - HtlcAcceptedRequest { - onion: crate::lsps2::cln::Onion { - short_channel_id: scid, - payload, - next_onion: vec![], - forward_msat: None, - outgoing_cltv_value: None, - shared_secret: vec![], - total_msat: None, - type_: None, - }, - htlc: crate::lsps2::cln::Htlc { - amount_msat: Amount::from_msat(amount_msat), - cltv_expiry: 100, - cltv_expiry_relative: 10, - payment_hash: vec![], - extra_tlvs: None, - short_channel_id: ShortChannelId::from(123456789u64), - id: 0, - }, - forward_to: None, - } - } - - fn create_test_datastore_entry( - peer_id: PublicKey, - expected_payment_size: Option, - ) -> DatastoreEntry { - let (_, policy) = params_with_promise(&[0u8; 32]); - DatastoreEntry { - peer_id, - opening_fee_params: policy, - expected_payment_size, - } - } - - fn test_promise_secret() -> [u8; 32] { - [0x42; 32] - } - - #[tokio::test] - async fn test_successful_get_info() { - let promise_secret = test_promise_secret(); - let params = Lsps2PolicyGetInfoResponse { - client_rejected: false, - policy_opening_fee_params_menu: vec![PolicyOpeningFeeParams { - min_fee_msat: Msat(2000), - proportional: Ppm(10000), - valid_until: Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1000000), - max_payment_size_msat: Msat(100000000), - }], - }; - let promise = params.policy_opening_fee_params_menu[0].get_hmac_hex(&promise_secret); - let fake = FakeCln::default(); - *fake.lsps2_getpolicy_response.lock().unwrap() = Some(params); - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let request = Lsps2GetInfoRequest { token: None }; - let result = handler.handle_get_info(request).await.unwrap(); - - assert_eq!( - result.opening_fee_params_menu[0].min_payment_size_msat, - Msat(1000000) - ); - assert_eq!( - result.opening_fee_params_menu[0].max_payment_size_msat, - Msat(100000000) - ); - assert_eq!( - result.opening_fee_params_menu[0].promise, - promise.try_into().unwrap() - ); - } - - #[tokio::test] - async fn test_get_info_rpc_error_handling() { - let promise_secret = test_promise_secret(); - let fake = FakeCln::default(); - *fake.lsps2_getpolicy_error.lock().unwrap() = Some(ClnRpcError { - code: Some(-1), - message: "not found".to_string(), - data: None, - }); - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let request = Lsps2GetInfoRequest { token: None }; - let result = handler.handle_get_info(request).await; - - assert!(result.is_err()); - let error = result.unwrap_err(); - assert_eq!(error.code, jsonrpc::INTERNAL_ERROR); - assert!(error.message.contains("internal error")); - } - - #[tokio::test] - async fn buy_ok_fixed_amount() { - let promise_secret = test_promise_secret(); - let mut fake = FakeCln::default(); - fake.blockheight_response = Some(900_000); - fake.store_buy_request_response = true; - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let (_policy, buy) = params_with_promise(&promise_secret); - - // Set payment_size_msat => "MPP+fixed-invoice" mode. - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(2_000_000)), - }; - let peer_id = create_peer_id(); - - let result = handler.handle_buy(peer_id, request).await.unwrap(); - - assert_eq!(result.lsp_cltv_expiry_delta, DEFAULT_CLTV_EXPIRY_DELTA); - assert!(!result.client_trusts_lsp); - assert!(result.jit_channel_scid.to_u64() > 0); - } - - #[tokio::test] - async fn buy_ok_variable_amount_no_payment_size() { - let promise_secret = test_promise_secret(); - let mut fake = FakeCln::default(); - fake.blockheight_response = Some(900_100); - fake.store_buy_request_response = true; - - let handler = Lsps2ServiceHandler { - api: Arc::new(fake), - promise_secret, - }; - - let (_policy, buy) = params_with_promise(&promise_secret); - - // No payment_size_msat => "no-MPP+var-invoice" mode. - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: None, - }; - let peer_id = create_peer_id(); - - let result = handler.handle_buy(peer_id, request).await; - - assert!(result.is_ok()); - } - - #[tokio::test] - async fn buy_rejects_invalid_promise_or_past_valid_until_with_201() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Case A: wrong promise (derive with different secret) - let (_policy_wrong, mut buy_wrong) = params_with_promise(&[9u8; 32]); - buy_wrong.valid_until = Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(); // future, so only promise is wrong - let req_wrong = Lsps2BuyRequest { - opening_fee_params: buy_wrong, - payment_size_msat: Some(Msat(2_000_000)), - }; - let peer_id = create_peer_id(); - - let err1 = handler.handle_buy(peer_id, req_wrong).await.unwrap_err(); - assert_eq!(err1.code, 201); - - // Case B: past valid_until - let (_policy, mut buy_past) = params_with_promise(&promise_secret); - buy_past.valid_until = Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // past - let req_past = Lsps2BuyRequest { - opening_fee_params: buy_past, - payment_size_msat: Some(Msat(2_000_000)), - }; - let err2 = handler.handle_buy(peer_id, req_past).await.unwrap_err(); - assert_eq!(err2.code, 201); - } - - #[tokio::test] - async fn buy_rejects_when_opening_fee_ge_payment_size_with_202() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Make min_fee already >= payment_size to trigger 202 - let policy = PolicyOpeningFeeParams { - min_fee_msat: Msat(10_000), - proportional: Ppm(0), // no extra percentage - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1), - max_payment_size_msat: Msat(u64::MAX / 2), - }; - let hex = policy.get_hmac_hex(&promise_secret); - let promise: Promise = hex.try_into().unwrap(); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(9_999)), // strictly less than min_fee => opening_fee >= payment_size - }; - let peer_id = create_peer_id(); - let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - - assert_eq!(err.code, 202); - } - - #[tokio::test] - async fn buy_rejects_on_fee_overflow_with_203() { - let promise_secret = test_promise_secret(); - let handler = Lsps2ServiceHandler { - api: Arc::new(FakeCln::default()), - promise_secret, - }; - - // Choose values likely to overflow if multiplication isn't checked: - // opening_fee = min_fee + payment_size * proportional / 1_000_000 - let policy = PolicyOpeningFeeParams { - min_fee_msat: Msat(u64::MAX / 2), - proportional: Ppm(u32::MAX), // 4_294_967_295 ppm - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 42, - min_payment_size_msat: Msat(1), - max_payment_size_msat: Msat(u64::MAX), - }; - let hex = policy.get_hmac_hex(&promise_secret); - let promise: Promise = hex.try_into().unwrap(); - let buy = OpeningFeeParams { - min_fee_msat: policy.min_fee_msat, - proportional: policy.proportional, - valid_until: policy.valid_until, - min_lifetime: policy.min_lifetime, - max_client_to_self_delay: policy.max_client_to_self_delay, - min_payment_size_msat: policy.min_payment_size_msat, - max_payment_size_msat: policy.max_payment_size_msat, - promise, - }; - - let request = Lsps2BuyRequest { - opening_fee_params: buy, - payment_size_msat: Some(Msat(u64::MAX / 2)), - }; - let peer_id = create_peer_id(); - let err = handler.handle_buy(peer_id, request).await.unwrap_err(); - - assert_eq!(err.code, 203); - } - #[tokio::test] - async fn test_htlc_no_scid_continues() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake, 1000); - - // HTLC with no short_channel_id (final destination) - let req = create_test_htlc_request(None, 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - } - - #[tokio::test] - async fn test_htlc_unknown_scid_continues() { - let fake = FakeCln::default(); - - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let scid = ShortChannelId::from(123456789u64); - - let req = create_test_htlc_request(Some(scid), 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - } - - #[tokio::test] - async fn test_htlc_expired_fee_menu_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - // Create datastore entry with expired fee menu - let mut ds_entry = create_test_datastore_entry(peer_id, None); - ds_entry.opening_fee_params.valid_until = - Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap(); // expired - - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - let req = create_test_htlc_request(Some(scid), 1000000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - TEMPORARY_CHANNEL_FAILURE.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_amount_too_low_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount below minimum - let req = create_test_htlc_request(Some(scid), 100); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_amount_too_high_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount above maximum - let req = create_test_htlc_request(Some(scid), 200_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_amount_doesnt_cover_fee_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // HTLC amount just barely covers minimum fee but not minimum HTLC - let req = create_test_htlc_request(Some(scid), 2500); // min_fee is 2000, htlc_minimum is 1000 - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_channel_capacity_request_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_error.lock().unwrap() = Some(ClnRpcError { - code: Some(-1), - message: "capacity check failed".to_string(), - data: None, - }); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_policy_denies_channel() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - // Policy response with no channel capacity (denied) - *fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: None, - }); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_fund_channel_fails() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: Some(50_000_000), - }); - - *fake.fund_channel_error.lock().unwrap() = Some(anyhow::anyhow!("insufficient funds")); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } - - #[tokio::test] - async fn test_htlc_successful_flow() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler { - api: fake.clone(), - htlc_minimum_msat: 1000, - backoff_listpeerchannels: Duration::from_millis(10), - }; - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - let ds_entry = create_test_datastore_entry(peer_id, None); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - *fake.lsps2_getchannelcapacity_response.lock().unwrap() = - Some(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: Some(50_000_000), - }); - - *fake.fund_channel_response.lock().unwrap() = - Some((*Sha256::from_bytes_ref(&[1u8; 32]), String::default())); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Continue); - - assert!(result.payload.is_some()); - assert!(result.extra_tlvs.is_some()); - assert!(result.forward_to.is_some()); - - // The payload should have the deducted amount - let payload_bytes = result.payload.unwrap(); - let payload_tlv = TlvStream::from_bytes(&payload_bytes).unwrap(); - - // Should contain forward amount. - assert!(payload_tlv.get(TLV_FORWARD_AMT).is_some()); - } - - #[tokio::test] - #[ignore] // We deactivate the mpp check on the experimental server for - // client side checks. - async fn test_htlc_mpp_not_implemented() { - let fake = FakeCln::default(); - let handler = HtlcAcceptedHookHandler::new(fake.clone(), 1000); - let peer_id = create_peer_id(); - let scid = ShortChannelId::from(123456789u64); - - // Create entry with expected_payment_size (MPP mode) - let mut ds_entry = create_test_datastore_entry(peer_id, None); - ds_entry.expected_payment_size = Some(Msat::from_msat(1000000)); - *fake.get_buy_request_response.lock().unwrap() = Some(ds_entry); - - let req = create_test_htlc_request(Some(scid), 10_000_000); - - let result = handler.handle(req).await.unwrap(); - assert_eq!(result.result, HtlcAcceptedResult::Fail); - assert_eq!( - result.failure_message.unwrap(), - UNKNOWN_NEXT_PEER.to_string() - ); - } -} diff --git a/plugins/lsps-plugin/src/core/lsps2/htlc.rs b/plugins/lsps-plugin/src/core/lsps2/htlc.rs deleted file mode 100644 index 6e39cc07cf51..000000000000 --- a/plugins/lsps-plugin/src/core/lsps2/htlc.rs +++ /dev/null @@ -1,802 +0,0 @@ -use crate::{ - core::{ - lsps2::provider::{DatastoreProvider, LightningProvider, Lsps2OfferProvider}, - tlv::{TlvStream, TLV_FORWARD_AMT}, - }, - proto::{ - lsps0::{Msat, ShortChannelId}, - lsps2::{ - compute_opening_fee, - failure_codes::{TEMPORARY_CHANNEL_FAILURE, UNKNOWN_NEXT_PEER}, - Lsps2PolicyGetChannelCapacityRequest, - }, - }, -}; -use bitcoin::hashes::sha256::Hash; -use chrono::Utc; -use std::time::Duration; -use thiserror::Error; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum HtlcDecision { - NotOurs, - Forward { - payload: TlvStream, - forward_to: Hash, - extra_tlvs: TlvStream, - }, - - Reject { - reason: RejectReason, - }, -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum RejectReason { - OfferExpired { valid_until: chrono::DateTime }, - AmountBelowMinimum { minimum: Msat }, - AmountAboveMaximum { maximum: Msat }, - InsufficientForFee { fee: Msat }, - FeeOverflow, - PolicyDenied, - FundingFailed, - - // temporarily - MppNotSupported, -} - -impl RejectReason { - pub fn failure_code(&self) -> &'static str { - match self { - Self::OfferExpired { .. } => TEMPORARY_CHANNEL_FAILURE, - _ => UNKNOWN_NEXT_PEER, - } - } -} - -#[derive(Debug, Error)] -pub enum HtlcError { - #[error("failed to query channel capacity: {0}")] - CapacityQuery(#[source] anyhow::Error), - #[error("failed to fund channel: {0}")] - FundChannel(#[source] anyhow::Error), - #[error("channel ready check failed: {0}")] - ChannelReadyCheck(#[source] anyhow::Error), -} - -#[derive(Debug, Clone)] -pub struct Htlc { - pub amount_msat: Msat, - pub extra_tlvs: TlvStream, -} -impl Htlc { - pub fn new(amount_msat: Msat, tlvs: TlvStream) -> Self { - Self { - amount_msat, - extra_tlvs: tlvs, - } - } -} - -#[derive(Debug, Clone)] -pub struct Onion { - pub short_channel_id: ShortChannelId, - pub payload: TlvStream, -} - -pub struct HtlcAcceptedHookHandler { - api: A, - htlc_minimum_msat: u64, - backoff_listpeerchannels: Duration, -} - -impl HtlcAcceptedHookHandler { - pub fn new(api: A, htlc_minimum_msat: u64) -> Self { - Self { - api, - htlc_minimum_msat, - backoff_listpeerchannels: Duration::from_secs(10), - } - } -} -impl HtlcAcceptedHookHandler { - pub async fn handle(&self, htlc: &Htlc, onion: &Onion) -> Result { - // A) Is this SCID one that we care about? - let ds_rec = match self.api.get_buy_request(&onion.short_channel_id).await { - Ok(rec) => rec, - Err(_) => return Ok(HtlcDecision::NotOurs), - }; - - // Fixme: Check that we don't have a channel yet with the peer that we await to - // become READY to use. - // --- - - // Fixme: We only accept no-mpp for now, mpp and other flows will be added later on - // Fixme: We continue mpp for now to let the test mock handle the htlc, as we need - // to test the client implementation for mpp payments. - if ds_rec.expected_payment_size.is_some() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::MppNotSupported, - }); - } - - // B) Is the fee option menu still valid? - if Utc::now() >= ds_rec.opening_fee_params.valid_until { - // Not valid anymore, remove from DS and fail HTLC. - let _ = self.api.del_buy_request(&onion.short_channel_id).await; - return Ok(HtlcDecision::Reject { - reason: RejectReason::OfferExpired { - valid_until: ds_rec.opening_fee_params.valid_until, - }, - }); - } - - // C) Is the amount in the boundaries of the fee menu? - if htlc.amount_msat.msat() < ds_rec.opening_fee_params.min_fee_msat.msat() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::AmountBelowMinimum { - minimum: ds_rec.opening_fee_params.min_fee_msat, - }, - }); - } - - if htlc.amount_msat.msat() > ds_rec.opening_fee_params.max_payment_size_msat.msat() { - return Ok(HtlcDecision::Reject { - reason: RejectReason::AmountAboveMaximum { - maximum: ds_rec.opening_fee_params.max_payment_size_msat, - }, - }); - } - - // D) Check that the amount_msat covers the opening fee (only for non-mpp right now) - let opening_fee = match compute_opening_fee( - htlc.amount_msat.msat(), - ds_rec.opening_fee_params.min_fee_msat.msat(), - ds_rec.opening_fee_params.proportional.ppm() as u64, - ) { - Some(fee) if fee + self.htlc_minimum_msat < htlc.amount_msat.msat() => fee, - Some(fee) => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::InsufficientForFee { - fee: Msat::from_msat(fee), - }, - }) - } - None => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::FeeOverflow, - }) - } - }; - - // E) We made it, open a channel to the peer. - let ch_cap_req = Lsps2PolicyGetChannelCapacityRequest { - opening_fee_params: ds_rec.opening_fee_params, - init_payment_size: htlc.amount_msat, - scid: onion.short_channel_id, - }; - let ch_cap_res = self - .api - .get_channel_capacity(&ch_cap_req) - .await - .map_err(HtlcError::CapacityQuery)?; - - let cap = match ch_cap_res.channel_capacity_msat { - Some(c) => Msat::from_msat(c), - None => { - return Ok(HtlcDecision::Reject { - reason: RejectReason::PolicyDenied, - }) - } - }; - - // We take the policy-giver seriously, if the capacity is too low, we - // still try to open the channel. - // Fixme: We may check that the capacity is ge than the - // (amount_msat - opening fee) in the future. - // Fixme: Make this configurable, maybe return the whole request from - // the policy giver? - let (channel_id, _) = self - .api - .fund_jit_channel(&ds_rec.peer_id, &cap) - .await - .map_err(HtlcError::FundChannel)?; - - // F) Wait for the peer to send `channel_ready`. - // Fixme: Use event to check for channel ready, - // Fixme: Check for htlc timeout if peer refuses to send "ready". - // Fixme: handle unexpected channel states. - loop { - match self - .api - .is_channel_ready(&ds_rec.peer_id, &channel_id) - .await - { - Ok(true) => break, - Ok(false) => tokio::time::sleep(self.backoff_listpeerchannels).await, - Err(e) => return Err(HtlcError::ChannelReadyCheck(e)), - }; - } - - // G) We got a working channel, deduct fee and forward htlc. - let deducted_amt_msat = htlc.amount_msat.msat() - opening_fee; - let mut payload = onion.payload.clone(); - payload.set_tu64(TLV_FORWARD_AMT, deducted_amt_msat); - - let mut extra_tlvs = htlc.extra_tlvs.clone(); - extra_tlvs.set_u64(65537, opening_fee); - - Ok(HtlcDecision::Forward { - payload, - forward_to: channel_id, - extra_tlvs, - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::tlv::TlvStream; - use crate::proto::lsps0::{Msat, Ppm, ShortChannelId}; - use crate::proto::lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, Promise, - }; - use anyhow::{anyhow, Result as AnyResult}; - use async_trait::async_trait; - use bitcoin::hashes::{sha256::Hash as Sha256, Hash}; - use bitcoin::secp256k1::PublicKey; - use chrono::{TimeZone, Utc}; - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::{Arc, Mutex}; - use std::time::Duration; - use std::u64; - - fn test_peer_id() -> PublicKey { - "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798" - .parse() - .unwrap() - } - - fn test_scid() -> ShortChannelId { - ShortChannelId::from(123456789u64) - } - - fn test_channel_id() -> Sha256 { - Sha256::from_byte_array([1u8; 32]) - } - - fn valid_opening_fee_params() -> OpeningFeeParams { - OpeningFeeParams { - min_fee_msat: Msat(2_000), - proportional: Ppm(10_000), // 1% - valid_until: Utc.with_ymd_and_hms(2100, 1, 1, 0, 0, 0).unwrap(), - min_lifetime: 1000, - max_client_to_self_delay: 2016, - min_payment_size_msat: Msat(1_000_000), - max_payment_size_msat: Msat(100_000_000), - promise: Promise::try_from("test").unwrap(), - } - } - - fn expired_opening_fee_params() -> OpeningFeeParams { - OpeningFeeParams { - valid_until: Utc.with_ymd_and_hms(2000, 1, 1, 0, 0, 0).unwrap(), - ..valid_opening_fee_params() - } - } - - fn test_datastore_entry(expected_payment_size: Option) -> DatastoreEntry { - DatastoreEntry { - peer_id: test_peer_id(), - opening_fee_params: valid_opening_fee_params(), - expected_payment_size, - } - } - - fn test_onion(scid: ShortChannelId, payload: TlvStream) -> Onion { - Onion { - short_channel_id: scid, - payload, - } - } - - fn test_htlc(amount_msat: u64, extra_tlvs: TlvStream) -> Htlc { - Htlc { - amount_msat: Msat::from_msat(amount_msat), - extra_tlvs, - } - } - - #[derive(Default, Clone)] - struct MockApi { - // Datastore - buy_request: Arc>>, - buy_request_error: Arc>, - del_called: Arc, - - // Policy - channel_capacity: Arc>>>, // Some(Some(cap)), Some(None) = denied, None = error - channel_capacity_error: Arc>, - - // Lightning - fund_result: Arc>>, - fund_error: Arc>, - channel_ready: Arc>, - channel_ready_checks: Arc, - } - - impl MockApi { - fn new() -> Self { - Self::default() - } - - fn with_buy_request(self, entry: DatastoreEntry) -> Self { - *self.buy_request.lock().unwrap() = Some(entry); - self - } - - fn with_no_buy_request(self) -> Self { - *self.buy_request_error.lock().unwrap() = true; - self - } - - fn with_channel_capacity(self, capacity_msat: u64) -> Self { - *self.channel_capacity.lock().unwrap() = Some(Some(capacity_msat)); - self - } - - fn with_channel_denied(self) -> Self { - *self.channel_capacity.lock().unwrap() = Some(None); - self - } - - fn with_channel_capacity_error(self) -> Self { - *self.channel_capacity_error.lock().unwrap() = true; - self - } - - fn with_fund_result(self, channel_id: Sha256, txid: &str) -> Self { - *self.fund_result.lock().unwrap() = Some((channel_id, txid.to_string())); - self - } - - fn with_fund_error(self) -> Self { - *self.fund_error.lock().unwrap() = true; - self - } - - fn with_channel_ready(self, ready: bool) -> Self { - *self.channel_ready.lock().unwrap() = ready; - self - } - - fn del_call_count(&self) -> usize { - self.del_called.load(Ordering::SeqCst) - } - - fn channel_ready_check_count(&self) -> usize { - self.channel_ready_checks.load(Ordering::SeqCst) - } - } - - #[async_trait] - impl DatastoreProvider for MockApi { - async fn store_buy_request( - &self, - _scid: &ShortChannelId, - _peer_id: &PublicKey, - _fee_params: &OpeningFeeParams, - _payment_size: &Option, - ) -> AnyResult { - unimplemented!("not needed for HTLC tests") - } - - async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { - if *self.buy_request_error.lock().unwrap() { - return Err(anyhow!("not found")); - } - self.buy_request - .lock() - .unwrap() - .clone() - .ok_or_else(|| anyhow!("not found")) - } - - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { - self.del_called.fetch_add(1, Ordering::SeqCst); - Ok(()) - } - } - - #[async_trait] - impl Lsps2OfferProvider for MockApi { - async fn get_offer( - &self, - _request: &Lsps2PolicyGetInfoRequest, - ) -> AnyResult { - unimplemented!("not needed for HTLC tests") - } - - async fn get_channel_capacity( - &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - if *self.channel_capacity_error.lock().unwrap() { - return Err(anyhow!("capacity error")); - } - let cap = self - .channel_capacity - .lock() - .unwrap() - .ok_or_else(|| anyhow!("no capacity set"))?; - Ok(Lsps2PolicyGetChannelCapacityResponse { - channel_capacity_msat: cap, - }) - } - } - - #[async_trait] - impl LightningProvider for MockApi { - async fn fund_jit_channel( - &self, - _peer_id: &PublicKey, - _amount: &Msat, - ) -> AnyResult<(Sha256, String)> { - if *self.fund_error.lock().unwrap() { - return Err(anyhow!("fund error")); - } - self.fund_result - .lock() - .unwrap() - .clone() - .ok_or_else(|| anyhow!("no fund result set")) - } - - async fn is_channel_ready( - &self, - _peer_id: &PublicKey, - _channel_id: &Sha256, - ) -> AnyResult { - self.channel_ready_checks.fetch_add(1, Ordering::SeqCst); - Ok(*self.channel_ready.lock().unwrap()) - } - } - - fn handler(api: MockApi) -> HtlcAcceptedHookHandler { - HtlcAcceptedHookHandler { - api, - htlc_minimum_msat: 1_000, - backoff_listpeerchannels: Duration::from_millis(1), // Fast for tests - } - } - - #[tokio::test] - async fn continues_when_scid_not_found() { - let api = MockApi::new().with_no_buy_request(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert_eq!(result, HtlcDecision::NotOurs); - } - - #[tokio::test] - async fn continues_when_mpp_payment() { - let entry = test_datastore_entry(Some(Msat(50_000_000))); // MPP = has expected size - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert_eq!( - result, - HtlcDecision::Reject { - reason: RejectReason::MppNotSupported - } - ); - } - - #[tokio::test] - async fn fails_when_offer_expired() { - let mut entry = test_datastore_entry(None); - entry.opening_fee_params = expired_opening_fee_params(); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api.clone()); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::OfferExpired { .. } - } - )); - assert_eq!(api.del_call_count(), 1); // Should delete expired entry - } - - #[tokio::test] - async fn fails_when_amount_below_min_fee() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // min_fee_msat is 2_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(1_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::AmountBelowMinimum { .. } - } - )); - } - - #[tokio::test] - async fn fails_when_amount_above_max() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // max_payment_size_msat is 100_000_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(200_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::AmountAboveMaximum { .. } - } - )); - } - - #[tokio::test] - async fn fails_when_amount_doesnt_cover_fee_plus_minimum() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - // min_fee = 2_000, htlc_minimum = 1_000 - // Amount must be > fee + htlc_minimum - // At 3_000: fee ~= 2_000 + (3_000 * 10_000 / 1_000_000) = 2_030 - // 2_030 + 1_000 = 3_030 > 3_000, so should fail - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(3_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::InsufficientForFee { .. } - } - )); - } - - #[tokio::test] - async fn fails_when_fee_computation_overflows() { - let mut entry = test_datastore_entry(None); - entry.opening_fee_params.min_fee_msat = Msat(u64::MAX / 2); - entry.opening_fee_params.proportional = Ppm(u32::MAX); - entry.opening_fee_params.min_payment_size_msat = Msat(1); - entry.opening_fee_params.max_payment_size_msat = Msat(u64::MAX); - - let api = MockApi::new().with_buy_request(entry); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(u64::MAX / 2, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::FeeOverflow, - } - )); - } - - #[tokio::test] - async fn fails_when_channel_capacity_errors() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity_error(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.expect_err("should fail"); - - assert!(matches!(result, HtlcError::CapacityQuery(_))); - } - - #[tokio::test] - async fn fails_when_policy_denies_channel() { - let entry = test_datastore_entry(None); - let api = MockApi::new().with_buy_request(entry).with_channel_denied(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!( - result, - HtlcDecision::Reject { - reason: RejectReason::PolicyDenied, - } - )); - } - - #[tokio::test] - async fn fails_when_fund_channel_errors() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_error(); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.expect_err("should fail"); - - assert!(matches!(result, HtlcError::FundChannel(_))); - } - - #[tokio::test] - async fn success_flow_continues_with_modified_payload() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api.clone()); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { - payload, - forward_to, - extra_tlvs, - } = result - else { - panic!("expected forward, got {:?}", result) - }; - - assert_eq!(forward_to, test_channel_id()); - assert!(!payload.0.is_empty()); - assert!(!extra_tlvs.0.is_empty()); - } - - #[tokio::test] - async fn polls_until_channel_ready() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(false); - - let h = handler(api.clone()); - - // Spawn handler, will block on channel ready - let handle = tokio::spawn(async move { - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - result - }); - - // Let it poll a few times - tokio::time::sleep(Duration::from_millis(10)).await; - assert!(api.channel_ready_check_count() > 1); - - // Now make channel ready - *api.channel_ready.lock().unwrap() = true; - - let result = handle.await.unwrap(); - assert!(matches!(result, HtlcDecision::Forward { .. })); - } - - #[tokio::test] - async fn deducts_fee_from_forward_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { payload, .. } = result else { - panic!("expected forward, got {:?}", result) - }; - - // Verify payload contains deducted amount - // fee = max(min_fee, amount * proportional / 1_000_000) - // fee = max(2_000, 10_000_000 * 10_000 / 1_000_000) = max(2_000, 100_000) = 100_000 - // deducted = 10_000_000 - 100_000 = 9_900_000 - let forward_amt = payload.get_tu64(TLV_FORWARD_AMT).unwrap(); - assert_eq!(forward_amt, Some(9_900_000)); - } - - #[tokio::test] - async fn extra_tlvs_contain_opening_fee() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(10_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - let HtlcDecision::Forward { extra_tlvs, .. } = result else { - panic!("expected forward, got {:?}", result) - }; - - // Opening fee should be in TLV 65537 - let opening_fee = extra_tlvs.get_u64(65537).unwrap(); - assert_eq!(opening_fee, Some(100_000)); // Same fee calculation as above - } - - #[tokio::test] - async fn handles_minimum_valid_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(50_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - // Just enough to cover fee + htlc_minimum - // fee at 1_000_000 = max(2_000, 1_000_000 * 10_000 / 1_000_000) = max(2_000, 10_000) = 10_000 - // Need: fee + htlc_minimum < amount - // 10_000 + 1_000 = 11_000 < 1_000_000 ✓ - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(1_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!(result, HtlcDecision::Forward { .. })); - } - - #[tokio::test] - async fn handles_maximum_valid_amount() { - let entry = test_datastore_entry(None); - let api = MockApi::new() - .with_buy_request(entry) - .with_channel_capacity(200_000_000) - .with_fund_result(test_channel_id(), "txid123") - .with_channel_ready(true); - let h = handler(api); - - // max_payment_size_msat is 100_000_000 - let onion = test_onion(test_scid(), TlvStream::default()); - let htlc = test_htlc(100_000_000, TlvStream::default()); - let result = h.handle(&htlc, &onion).await.unwrap(); - - assert!(matches!(result, HtlcDecision::Forward { .. })); - } -} diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs new file mode 100644 index 000000000000..6ef427405479 --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -0,0 +1,563 @@ +use super::actor::{ActionExecutor, ActorInboxHandle, HtlcResponse}; +use super::provider::DatastoreProvider; +use super::session::{PaymentPart, Session}; +use crate::core::lsps2::actor::SessionActor; +use crate::proto::lsps0::ShortChannelId; +pub use bitcoin::hashes::sha256::Hash as PaymentHash; +use log::debug; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Debug, thiserror::Error)] +pub enum ManagerError { + #[error("session terminated")] + SessionTerminated, + #[error("datastore lookup failed: {0}")] + DatastoreLookup(#[source] anyhow::Error), +} + +pub struct SessionConfig { + pub max_parts: usize, + pub collect_timeout_secs: u64, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + max_parts: 30, // Core-Lightning default. + collect_timeout_secs: 90, // Blip52 default. + } + } +} + +pub struct SessionManager { + sessions: Mutex>, + datastore: Arc, + executor: Arc, + config: SessionConfig, +} + +impl + SessionManager +{ + pub fn new(datastore: Arc, executor: Arc, config: SessionConfig) -> Self { + Self { + sessions: Mutex::new(HashMap::new()), + datastore, + executor, + config, + } + } + + pub async fn on_part( + &self, + payment_hash: PaymentHash, + scid: ShortChannelId, + part: PaymentPart, + ) -> Result { + let handle = { + let mut sessions = self.sessions.lock().await; + if let Some(handle) = sessions.get(&payment_hash) { + handle.clone() + } else { + let handle = self.create_session(&scid).await?; + sessions.insert(payment_hash, handle.clone()); + handle + } + }; + + match handle.add_part(part).await { + Ok(resp) => Ok(resp), + Err(_) => { + self.sessions.lock().await.remove(&payment_hash); + Err(ManagerError::SessionTerminated) + } + } + } + + pub async fn on_payment_settled( + &self, + payment_hash: PaymentHash, + preimage: Option, + ) -> Result<(), ManagerError> { + let handle = { + let sessions = self.sessions.lock().await; + match sessions.get(&payment_hash) { + Some(handle) => handle.clone(), + None => { + debug!("on_payment_settled: no session for {payment_hash}"); + return Ok(()); + } + } + }; + + match handle.payment_settled(preimage).await { + Ok(()) => Ok(()), + Err(_) => { + self.sessions.lock().await.remove(&payment_hash); + Err(ManagerError::SessionTerminated) + } + } + } + + pub async fn on_payment_failed( + &self, + payment_hash: PaymentHash, + ) -> Result<(), ManagerError> { + let handle = { + let sessions = self.sessions.lock().await; + match sessions.get(&payment_hash) { + Some(handle) => handle.clone(), + None => { + debug!("on_payment_failed: no session for {payment_hash}"); + return Ok(()); + } + } + }; + + match handle.payment_failed().await { + Ok(()) => Ok(()), + Err(_) => { + self.sessions.lock().await.remove(&payment_hash); + Err(ManagerError::SessionTerminated) + } + } + } + + pub async fn on_new_block(&self, height: u32) { + let handles: Vec<(PaymentHash, ActorInboxHandle)> = { + let sessions = self.sessions.lock().await; + sessions.iter().map(|(k, v)| (*k, v.clone())).collect() + }; + + let mut dead = Vec::new(); + for (hash, handle) in handles { + if handle.new_block(height).await.is_err() { + dead.push(hash); + } + } + + if !dead.is_empty() { + let mut sessions = self.sessions.lock().await; + for hash in dead { + sessions.remove(&hash); + } + } + } + + async fn create_session( + &self, + scid: &ShortChannelId, + ) -> Result { + let entry = self + .datastore + .get_buy_request(scid) + .await + .map_err(ManagerError::DatastoreLookup)?; + + let session = Session::new( + self.config.max_parts, + entry.opening_fee_params, + entry.expected_payment_size, + entry.channel_capacity_msat, + entry.peer_id.to_string(), + ); + + Ok(SessionActor::spawn_session_actor( + session, + self.executor.clone(), + entry.peer_id.to_string(), + self.config.collect_timeout_secs, + *scid, + self.datastore.clone(), + )) + } + + #[cfg(test)] + async fn session_count(&self) -> usize { + self.sessions.lock().await.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proto::lsps0::{Msat, Ppm}; + use crate::proto::lsps2::{DatastoreEntry, OpeningFeeParams, Promise, SessionOutcome}; + use async_trait::async_trait; + use bitcoin::hashes::Hash; + use chrono::{Duration as ChronoDuration, Utc}; + use std::time::Duration; + + fn test_payment_hash(byte: u8) -> PaymentHash { + PaymentHash::from_byte_array([byte; 32]) + } + + fn test_scid() -> ShortChannelId { + ShortChannelId::from(100u64 << 40 | 1u64 << 16) + } + + fn test_scid_2() -> ShortChannelId { + ShortChannelId::from(200u64 << 40 | 2u64 << 16) + } + + fn unknown_scid() -> ShortChannelId { + ShortChannelId::from(999u64 << 40 | 9u64 << 16 | 9) + } + + fn test_peer_id() -> cln_rpc::primitives::PublicKey { + serde_json::from_value(serde_json::json!( + "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" + )) + .unwrap() + } + + fn opening_fee_params(min_fee_msat: u64) -> OpeningFeeParams { + OpeningFeeParams { + min_fee_msat: Msat::from_msat(min_fee_msat), + proportional: Ppm::from_ppm(1_000), + valid_until: Utc::now() + ChronoDuration::hours(1), + min_lifetime: 144, + max_client_to_self_delay: 2016, + min_payment_size_msat: Msat::from_msat(1), + max_payment_size_msat: Msat::from_msat(u64::MAX), + promise: Promise("test-promise".to_owned()), + } + } + + fn test_datastore_entry() -> DatastoreEntry { + DatastoreEntry { + peer_id: test_peer_id(), + opening_fee_params: opening_fee_params(1), + expected_payment_size: Some(Msat::from_msat(1_000)), + channel_capacity_msat: Msat::from_msat(100_000_000), + created_at: Utc::now(), + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + } + } + + fn part(htlc_id: u64, amount_msat: u64) -> PaymentPart { + PaymentPart { + htlc_id, + amount_msat: Msat::from_msat(amount_msat), + cltv_expiry: 100, + } + } + + struct MockDatastore { + entries: HashMap, + } + + impl MockDatastore { + fn new() -> Self { + let mut entries = HashMap::new(); + entries.insert(test_scid().to_string(), test_datastore_entry()); + entries.insert(test_scid_2().to_string(), test_datastore_entry()); + Self { entries } + } + } + + #[async_trait] + impl DatastoreProvider for MockDatastore { + async fn store_buy_request( + &self, + _scid: &ShortChannelId, + _peer_id: &bitcoin::secp256k1::PublicKey, + _offer: &OpeningFeeParams, + _expected_payment_size: &Option, + _channel_capacity_msat: &Msat, + ) -> anyhow::Result { + Ok(true) + } + + async fn get_buy_request(&self, scid: &ShortChannelId) -> anyhow::Result { + self.entries + .get(&scid.to_string()) + .cloned() + .ok_or_else(|| anyhow::anyhow!("not found: {scid}")) + } + + async fn del_buy_request(&self, _scid: &ShortChannelId) -> anyhow::Result<()> { + Ok(()) + } + + async fn finalize_session( + &self, + _scid: &ShortChannelId, + _outcome: SessionOutcome, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn update_session_funding( + &self, + _scid: &ShortChannelId, + _channel_id: &str, + _funding_psbt: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn update_session_funding_txid( + &self, + _scid: &ShortChannelId, + _funding_txid: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn update_session_preimage( + &self, + _scid: &ShortChannelId, + _preimage: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + struct MockExecutor { + fund_succeeds: bool, + } + + #[async_trait] + impl ActionExecutor for MockExecutor { + async fn fund_channel( + &self, + _peer_id: String, + _channel_capacity_msat: Msat, + _opening_fee_params: OpeningFeeParams, + ) -> anyhow::Result<(String, String)> { + if self.fund_succeeds { + Ok(("channel-id-1".to_string(), "psbt-1".to_string())) + } else { + Err(anyhow::anyhow!("fund error")) + } + } + + async fn broadcast_tx( + &self, + _channel_id: String, + _funding_psbt: String, + ) -> anyhow::Result { + Ok("mock-txid".to_string()) + } + + async fn abandon_session( + &self, + _channel_id: String, + _funding_psbt: String, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn disconnect(&self, _peer_id: String) -> anyhow::Result<()> { + Ok(()) + } + + async fn is_channel_alive(&self, _channel_id: &str) -> anyhow::Result { + Ok(true) + } + } + + fn test_manager(fund_succeeds: bool) -> Arc> { + Arc::new(SessionManager::new( + Arc::new(MockDatastore::new()), + Arc::new(MockExecutor { fund_succeeds }), + SessionConfig { + max_parts: 3, + ..SessionConfig::default() + }, + )) + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn first_part_creates_session() { + let mgr = test_manager(true); + + let resp = mgr + .on_part(test_payment_hash(1), test_scid(), part(1, 1_000)) + .await + .unwrap(); + + assert!(matches!(resp, HtlcResponse::Forward { .. })); + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn second_part_routes_to_existing() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // First part reaches threshold (expected=1000) and gets Forward. + let resp1 = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp1, HtlcResponse::Forward { .. })); + + // Session is now in AwaitingSettlement. Second part is forwarded immediately. + let resp2 = mgr.on_part(hash, test_scid(), part(2, 500)).await.unwrap(); + match resp2 { + HtlcResponse::Forward { fee_msat, .. } => { + assert_eq!(fee_msat, 0, "late-arriving part should have zero fee"); + } + other => panic!("expected Forward, got {other:?}"), + } + + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn different_hashes_create_separate_sessions() { + let mgr = test_manager(true); + + let r1 = mgr + .on_part(test_payment_hash(1), test_scid(), part(1, 1_000)) + .await + .unwrap(); + let r2 = mgr + .on_part(test_payment_hash(2), test_scid_2(), part(2, 1_000)) + .await + .unwrap(); + + assert!(matches!(r1, HtlcResponse::Forward { .. })); + assert!(matches!(r2, HtlcResponse::Forward { .. })); + assert_eq!(mgr.session_count().await, 2); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn terminated_session_cleaned_up() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // First on_part with partial amount — won't reach threshold, blocks. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + + // Advance past 90s collect timeout. + tokio::time::sleep(Duration::from_secs(91)).await; + + // First part should have received Fail from timeout. + let resp = h1.await.unwrap().unwrap(); + assert!(matches!(resp, HtlcResponse::Fail { .. })); + + // Stale entry still in the map. + assert_eq!(mgr.session_count().await, 1); + + // Next on_part detects dead session and cleans up. + let err = mgr + .on_part(hash, test_scid(), part(2, 500)) + .await + .unwrap_err(); + assert!(matches!(err, ManagerError::SessionTerminated { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn datastore_lookup_failure() { + let mgr = test_manager(true); + + let err = mgr + .on_part(test_payment_hash(1), unknown_scid(), part(1, 1_000)) + .await + .unwrap_err(); + + assert!(matches!(err, ManagerError::DatastoreLookup { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_unknown_hash_is_ok() { + let mgr = test_manager(true); + let result = mgr.on_payment_settled(test_payment_hash(99), None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_active_session() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create session and forward payment. + let resp = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp, HtlcResponse::Forward { .. })); + + // Settle payment — session is in AwaitingSettlement. + let result = mgr.on_payment_settled(hash, None).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_settled_stale_session_cleaned_up() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create a session with a partial amount — won't reach threshold. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + + // Advance past 90s collect timeout → actor dies. + tokio::time::sleep(Duration::from_secs(91)).await; + let resp = h1.await.unwrap().unwrap(); + assert!(matches!(resp, HtlcResponse::Fail { .. })); + + // Stale entry remains. + assert_eq!(mgr.session_count().await, 1); + + // on_payment_settled hits dead handle → removes entry. + let err = mgr.on_payment_settled(hash, None).await.unwrap_err(); + assert!(matches!(err, ManagerError::SessionTerminated { .. })); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_failed_unknown_hash_is_ok() { + let mgr = test_manager(true); + let result = mgr.on_payment_failed(test_payment_hash(99)).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn payment_failed_active_session() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Create session and forward payment. + let resp = mgr + .on_part(hash, test_scid(), part(1, 1_000)) + .await + .unwrap(); + assert!(matches!(resp, HtlcResponse::Forward { .. })); + + // Fail payment — session is in AwaitingSettlement. + let result = mgr.on_payment_failed(hash).await; + assert!(result.is_ok()); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn concurrent_first_parts_same_hash() { + let mgr = test_manager(true); + let hash = test_payment_hash(1); + + // Two concurrent on_part calls for the same hash. + // expected_payment_size=1000, so two 500-msat parts reach threshold together. + let mgr2 = mgr.clone(); + let h1 = tokio::spawn(async move { mgr2.on_part(hash, test_scid(), part(1, 500)).await }); + let mgr3 = mgr.clone(); + let h2 = tokio::spawn(async move { mgr3.on_part(hash, test_scid(), part(2, 500)).await }); + + let r1 = h1.await.unwrap().unwrap(); + let r2 = h2.await.unwrap().unwrap(); + + assert!(matches!(r1, HtlcResponse::Forward { .. })); + assert!(matches!(r2, HtlcResponse::Forward { .. })); + assert_eq!(mgr.session_count().await, 1); + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/mod.rs b/plugins/lsps-plugin/src/core/lsps2/mod.rs index 18bf1cb51ce1..22cc81b36983 100644 --- a/plugins/lsps-plugin/src/core/lsps2/mod.rs +++ b/plugins/lsps-plugin/src/core/lsps2/mod.rs @@ -1,3 +1,5 @@ -pub mod htlc; +pub mod actor; +pub mod manager; pub mod provider; pub mod service; +pub mod session; diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index 6466630a4748..2d3b5c3ce2d2 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -1,14 +1,12 @@ use anyhow::Result; use async_trait::async_trait; -use bitcoin::hashes::sha256::Hash; use bitcoin::secp256k1::PublicKey; use crate::proto::{ lsps0::{Msat, ShortChannelId}, lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoRequest, - Lsps2PolicyGetInfoResponse, OpeningFeeParams, + DatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, Lsps2PolicyGetInfoRequest, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, SessionOutcome, }, }; @@ -27,27 +25,36 @@ pub trait DatastoreProvider: Send + Sync { peer_id: &PublicKey, offer: &OpeningFeeParams, expected_payment_size: &Option, + channel_capacity_msat: &Msat, ) -> Result; async fn get_buy_request(&self, scid: &ShortChannelId) -> Result; async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()>; -} -#[async_trait] -pub trait LightningProvider: Send + Sync { - async fn fund_jit_channel(&self, peer_id: &PublicKey, amount: &Msat) -> Result<(Hash, String)>; - async fn is_channel_ready(&self, peer_id: &PublicKey, channel_id: &Hash) -> Result; + async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()>; + + async fn update_session_funding( + &self, + scid: &ShortChannelId, + channel_id: &str, + funding_psbt: &str, + ) -> Result<()>; + + async fn update_session_funding_txid( + &self, + scid: &ShortChannelId, + funding_txid: &str, + ) -> Result<()>; + + async fn update_session_preimage(&self, scid: &ShortChannelId, preimage: &str) -> Result<()>; } #[async_trait] -pub trait Lsps2OfferProvider: Send + Sync { - async fn get_offer( +pub trait Lsps2PolicyProvider: Send + Sync { + async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, ) -> Result; - async fn get_channel_capacity( - &self, - params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> Result; + async fn buy(&self, request: &Lsps2PolicyBuyRequest) -> Result; } diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index a3ab32406e71..829a6e7ae7ec 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -1,6 +1,6 @@ use crate::{ core::{ - lsps2::provider::{BlockheightProvider, DatastoreProvider, Lsps2OfferProvider}, + lsps2::provider::{BlockheightProvider, DatastoreProvider, Lsps2PolicyProvider}, router::JsonRpcRouterBuilder, server::LspsProtocol, }, @@ -9,7 +9,8 @@ use crate::{ lsps0::{LSPS0RpcErrorExt as _, ShortChannelId}, lsps2::{ Lsps2BuyRequest, Lsps2BuyResponse, Lsps2GetInfoRequest, Lsps2GetInfoResponse, - Lsps2PolicyGetInfoRequest, OpeningFeeParams, ShortChannelIdJITExt, + Lsps2PolicyBuyRequest, Lsps2PolicyGetInfoRequest, OpeningFeeParams, + ShortChannelIdJITExt, }, }, register_handler, @@ -63,7 +64,7 @@ impl Lsps2ServiceHandler { } #[async_trait] -impl Lsps2Handler +impl Lsps2Handler for Lsps2ServiceHandler { async fn handle_get_info( @@ -72,7 +73,7 @@ impl ) -> std::result::Result { let res_data = self .api - .get_offer(&Lsps2PolicyGetInfoRequest { + .get_info(&Lsps2PolicyGetInfoRequest { token: request.token.clone(), }) .await @@ -116,9 +117,28 @@ impl // already handed out -> Check datastore entries. let jit_scid = ShortChannelId::generate_jit(blockheight, 12); // Approximately 2 hours in the future. + let ch_cap_res = self + .api + .buy(&Lsps2PolicyBuyRequest { + opening_fee_params: fee_params.clone(), + payment_size_msat: request.payment_size_msat, + }) + .await + .map_err(|_| RpcError::internal_error("internal error"))?; + + let channel_capacity_msat = ch_cap_res + .channel_capacity_msat + .ok_or_else(|| RpcError::internal_error("channel capacity denied by policy"))?; + let ok = self .api - .store_buy_request(&jit_scid, &peer_id, &fee_params, &request.payment_size_msat) + .store_buy_request( + &jit_scid, + &peer_id, + &fee_params, + &request.payment_size_msat, + &channel_capacity_msat, + ) .await .map_err(|_| RpcError::internal_error("internal error"))?; @@ -142,9 +162,9 @@ mod tests { use super::*; use crate::proto::lsps0::{Msat, Ppm}; use crate::proto::lsps2::{ - DatastoreEntry, Lsps2PolicyGetChannelCapacityRequest, - Lsps2PolicyGetChannelCapacityResponse, Lsps2PolicyGetInfoResponse, OpeningFeeParams, - PolicyOpeningFeeParams, Promise, + DatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, + Lsps2PolicyGetInfoResponse, OpeningFeeParams, PolicyOpeningFeeParams, Promise, + SessionOutcome, }; use anyhow::{anyhow, Result as AnyResult}; use chrono::{TimeZone, Utc}; @@ -188,11 +208,13 @@ mod tests { offer_response: Arc>>, blockheight: Arc>>, store_result: Arc>>, + buy_response: Arc>>>, // Errors offer_error: Arc>, blockheight_error: Arc>, store_error: Arc>, + buy_error: Arc>, // Capture calls stored_requests: Arc>>, @@ -254,14 +276,24 @@ mod tests { self } + fn with_buy_capacity(self, capacity_msat: u64) -> Self { + *self.buy_response.lock().unwrap() = Some(Some(Msat::from_msat(capacity_msat))); + self + } + + fn with_buy_error(self) -> Self { + *self.buy_error.lock().unwrap() = true; + self + } + fn stored_requests(&self) -> Vec { self.stored_requests.lock().unwrap().clone() } } #[async_trait] - impl Lsps2OfferProvider for MockApi { - async fn get_offer( + impl Lsps2PolicyProvider for MockApi { + async fn get_info( &self, _request: &Lsps2PolicyGetInfoRequest, ) -> AnyResult { @@ -275,11 +307,21 @@ mod tests { .ok_or_else(|| anyhow!("no offer response set")) } - async fn get_channel_capacity( + async fn buy( &self, - _params: &Lsps2PolicyGetChannelCapacityRequest, - ) -> AnyResult { - unimplemented!("not needed for service tests") + _request: &Lsps2PolicyBuyRequest, + ) -> AnyResult { + if *self.buy_error.lock().unwrap() { + return Err(anyhow!("buy error")); + } + let cap = self + .buy_response + .lock() + .unwrap() + .ok_or_else(|| anyhow!("no buy response set"))?; + Ok(Lsps2PolicyBuyResponse { + channel_capacity_msat: cap, + }) } } @@ -304,6 +346,7 @@ mod tests { peer_id: &PublicKey, _fee_params: &OpeningFeeParams, payment_size: &Option, + _channel_capacity_msat: &Msat, ) -> AnyResult { if *self.store_error.lock().unwrap() { return Err(anyhow!("store error")); @@ -324,6 +367,39 @@ mod tests { async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { unimplemented!("not needed for service tests") } + + async fn finalize_session( + &self, + _scid: &ShortChannelId, + _outcome: SessionOutcome, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn update_session_funding( + &self, + _scid: &ShortChannelId, + _channel_id: &str, + _funding_psbt: &str, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn update_session_funding_txid( + &self, + _scid: &ShortChannelId, + _funding_txid: &str, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn update_session_preimage( + &self, + _scid: &ShortChannelId, + _preimage: &str, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } } fn handler(api: MockApi) -> Lsps2ServiceHandler { @@ -408,7 +484,8 @@ mod tests { async fn buy_success_with_payment_size() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api.clone()); let request = Lsps2BuyRequest { @@ -434,7 +511,8 @@ mod tests { async fn buy_success_without_payment_size() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api.clone()); let request = Lsps2BuyRequest { @@ -537,7 +615,7 @@ mod tests { #[tokio::test] async fn buy_handles_blockheight_error() { - let api = MockApi::new().with_blockheight_error(); + let api = MockApi::new().with_blockheight_error().with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -553,7 +631,7 @@ mod tests { #[tokio::test] async fn buy_handles_store_error() { - let api = MockApi::new().with_blockheight(800_000).with_store_error(); + let api = MockApi::new().with_blockheight(800_000).with_store_error().with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -571,7 +649,8 @@ mod tests { async fn buy_handles_store_returns_false() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(false); + .with_store_result(false) + .with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { @@ -589,7 +668,8 @@ mod tests { async fn buy_generates_unique_scids() { let api = MockApi::new() .with_blockheight(800_000) - .with_store_result(true); + .with_store_result(true) + .with_buy_capacity(100_000_000); let h = handler(api); let request = Lsps2BuyRequest { diff --git a/plugins/lsps-plugin/src/core/lsps2/session.rs b/plugins/lsps-plugin/src/core/lsps2/session.rs index 298448548a8e..3d59012fc939 100644 --- a/plugins/lsps-plugin/src/core/lsps2/session.rs +++ b/plugins/lsps-plugin/src/core/lsps2/session.rs @@ -11,8 +11,6 @@ use crate::proto::{ #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum Error { - #[error("variable amount payments are not supported")] - UnimplementedVarAmount, #[error("opening fee computation overflow")] FeeOverflow, #[error("invalid state transition")] @@ -267,10 +265,6 @@ impl Session { // Collecting transitions. // (SessionState::Collecting { parts }, SessionInput::AddPart { part }) => { - if self.payment_size_msat.is_none() { - return Err(Error::UnimplementedVarAmount); - } - parts.push(part.clone()); let n_parts = parts.len(); let parts_sum = parts.iter().map(|p| p.amount_msat).sum(); @@ -281,24 +275,46 @@ impl Session { parts_sum, }]; - // Fail early if we have too many parts. - if n_parts > self.max_parts { - self.state = SessionState::Failed; - events.push(SessionEvent::TooManyParts { n_parts }); - events.push(SessionEvent::SessionFailed); - return Ok(ApplyResult { - actions: vec![ - SessionAction::FailHtlcs { - failure_code: UNKNOWN_NEXT_PEER, - }, - SessionAction::FailSession, - ], - events, - }); - } + // Variable-amount (None): first HTLC triggers immediately, second fails. + // Fixed-amount (Some): accumulate until threshold, fail if too many parts. + let threshold_reached = match self.payment_size_msat { + None => { + if n_parts > 1 { + self.state = SessionState::Failed; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + true + } + Some(_) => { + if n_parts > self.max_parts { + self.state = SessionState::Failed; + events.push(SessionEvent::TooManyParts { n_parts }); + events.push(SessionEvent::SessionFailed); + return Ok(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: UNKNOWN_NEXT_PEER, + }, + SessionAction::FailSession, + ], + events, + }); + } + parts_sum >= self.payment_size_msat.unwrap() + } + }; - let expected_msat = self.payment_size_msat.unwrap_or_else(|| Msat(0)); // We checked that it isn't None - if parts_sum >= expected_msat { + if threshold_reached { let opening_fee_msat = compute_opening_fee( parts_sum.msat(), self.opening_fee_params.min_fee_msat.msat(), @@ -1118,16 +1134,70 @@ mod tests { } #[test] - fn collecting_payment_size_none_errors_without_mutating_state() { + fn collecting_var_amount_single_htlc_triggers_funding() { let mut s = session(3, None, 1); - let err = s + let res = s .apply(SessionInput::AddPart { - part: part(1, 1_000), + part: part(1, 10_000_000), }) - .unwrap_err(); + .unwrap(); - assert_eq!(err, Error::UnimplementedVarAmount); - assert_eq!(s.state, SessionState::Collecting { parts: vec![] }); + assert!(matches!( + s.state, + SessionState::AwaitingChannelReady { .. } + )); + assert!(res + .actions + .iter() + .any(|a| matches!(a, SessionAction::FundChannel { .. }))); + assert!(res + .events + .iter() + .any(|e| matches!(e, SessionEvent::FundingChannel))); + } + + #[test] + fn collecting_var_amount_second_htlc_fails() { + // Set up a session with one part already in Collecting + let mut s = session(3, None, 1); + s.state = SessionState::Collecting { + parts: vec![part(1, 5_000_000)], + }; + let res = s + .apply(SessionInput::AddPart { + part: part(2, 5_000_000), + }) + .unwrap(); + + assert_eq!(s.state, SessionState::Failed); + assert!(res + .events + .iter() + .any(|e| matches!(e, SessionEvent::TooManyParts { n_parts: 2 }))); + assert!(res + .actions + .iter() + .any(|a| matches!(a, SessionAction::FailHtlcs { .. }))); + } + + #[test] + fn collecting_var_amount_fee_computed_on_htlc_amount() { + let mut s = session(3, None, 1); + let _ = s + .apply(SessionInput::AddPart { + part: part(1, 10_000_000), + }) + .unwrap(); + + // fee = max(min_fee=1000, 10_000_000 * 1000 / 1_000_000) = max(1000, 10_000) = 10_000 + if let SessionState::AwaitingChannelReady { + opening_fee_msat, .. + } = s.state + { + assert_eq!(opening_fee_msat, 10_000); + } else { + panic!("expected AwaitingChannelReady, got {:?}", s.state); + } } #[test] diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 2e9ae10c28ce..726613565f06 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -1,5 +1,6 @@ use anyhow::bail; use bitcoin::hashes::Hash; +use chrono::Utc; use cln_lsps::{ cln_adapters::{ hooks::service_custommsg_hook, rpc::ClnApiRpc, sender::ClnSender, state::ServiceState, @@ -7,14 +8,21 @@ use cln_lsps::{ }, core::{ lsps2::{ - htlc::{Htlc, HtlcAcceptedHookHandler, HtlcDecision, Onion, RejectReason}, + actor::HtlcResponse, + manager::{PaymentHash, SessionConfig, SessionManager}, + provider::DatastoreProvider, + session::PaymentPart, service::Lsps2ServiceHandler, }, server::LspsService, + tlv::{TlvStream, TLV_FORWARD_AMT}, + }, + proto::{ + lsps0::{Msat, ShortChannelId}, + lsps2::{failure_codes::UNKNOWN_NEXT_PEER, SessionOutcome}, }, - proto::lsps0::{Msat, LSPS0_MESSAGE_TYPE}, }; -use cln_plugin::{options, HookBuilder, HookFilter, Plugin}; +use cln_plugin::{options, Plugin}; use log::{debug, error, trace}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -30,23 +38,42 @@ pub const OPTION_PROMISE_SECRET: options::StringConfigOption = "A 64-character hex string that is the secret for promises", ); +pub const OPTION_COLLECT_TIMEOUT: options::DefaultIntegerConfigOption = + options::ConfigOption::new_i64_with_default( + "experimental-lsps2-collect-timeout", + 90, + "Timeout in seconds for collecting MPP parts (default: 90)", + ); + #[derive(Clone)] struct State { lsps_service: Arc, sender: ClnSender, lsps2_enabled: bool, + api: Arc, + session_manager: Arc>, } impl State { - pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32]) -> Self { + pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32], collect_timeout_secs: u64) -> Self { let api = Arc::new(ClnApiRpc::new(rpc_path.clone())); let sender = ClnSender::new(rpc_path); - let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(api, promise_secret)); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(api.clone(), promise_secret)); let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); + let session_manager = Arc::new(SessionManager::new( + api.clone(), + api.clone(), + SessionConfig { + collect_timeout_secs, + ..SessionConfig::default() + }, + )); Self { lsps_service, sender, lsps2_enabled: true, + api, + session_manager, } } } @@ -66,6 +93,7 @@ async fn main() -> Result<(), anyhow::Error> { if let Some(plugin) = cln_plugin::Builder::new(tokio::io::stdin(), tokio::io::stdout()) .option(OPTION_ENABLED) .option(OPTION_PROMISE_SECRET) + .option(OPTION_COLLECT_TIMEOUT) // FIXME: Temporarily disabled lsp feature to please test cases, this is // ok as the feature is optional per spec. // We need to ensure that `connectd` only starts after all plugins have @@ -78,11 +106,10 @@ async fn main() -> Result<(), anyhow::Error> { // cln_plugin::FeatureBitsKind::Init, // util::feature_bit_to_hex(LSP_FEATURE_BIT), // ) - .hook_from_builder( - HookBuilder::new("custommsg", service_custommsg_hook) - .filters(vec![HookFilter::Int(i64::from(LSPS0_MESSAGE_TYPE))]), - ) + .hook("custommsg", service_custommsg_hook) .hook("htlc_accepted", on_htlc_accepted) + .subscribe("forward_event", on_forward_event) + .subscribe("block_added", on_block_added) .configure() .await? { @@ -118,7 +145,8 @@ async fn main() -> Result<(), anyhow::Error> { } }; - let state = State::new(rpc_path, &secret); + let collect_timeout_secs = plugin.option(&OPTION_COLLECT_TIMEOUT)? as u64; + let state = State::new(rpc_path, &secret, collect_timeout_secs); let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -169,57 +197,161 @@ async fn handle_htlc_inner( } }; - let rpc_path = Path::new(&p.configuration().lightning_dir).join(&p.configuration().rpc_file); - let api = ClnApiRpc::new(rpc_path); - // Fixme: Use real htlc_minimum_amount. - let handler = HtlcAcceptedHookHandler::new(api, 1000); - - let onion = Onion { - short_channel_id, - payload: req.onion.payload, + // Decide path: look up buy request to check for MPP. + let ds_rec = match p.state().api.get_buy_request(&short_channel_id).await { + Ok(rec) => rec, + Err(_) => { + trace!("SCID not ours, continue."); + return Ok(json_continue()); + } }; - let htlc = Htlc { + if Utc::now() >= ds_rec.opening_fee_params.valid_until { + let _ = p + .state() + .api + .finalize_session(&short_channel_id, SessionOutcome::Timeout) + .await; + return Ok(json_fail(UNKNOWN_NEXT_PEER)); + } + + handle_session_htlc(p, &req, short_channel_id).await +} + +async fn handle_session_htlc( + p: &Plugin, + req: &HtlcAcceptedRequest, + scid: ShortChannelId, +) -> Result { + let payment_hash = + PaymentHash::from_byte_array(req.htlc.payment_hash.as_slice().try_into()?); + let part = PaymentPart { + htlc_id: req.htlc.id, amount_msat: Msat::from_msat(req.htlc.amount_msat.msat()), - extra_tlvs: req.htlc.extra_tlvs.unwrap_or_default(), + cltv_expiry: req.htlc.cltv_expiry, }; + match p + .state() + .session_manager + .on_part(payment_hash, scid, part) + .await + { + Ok(resp) => session_response_to_json( + resp, + &req.onion.payload, + req.htlc.amount_msat.msat(), + &req.htlc.extra_tlvs, + ), + Err(e) => { + debug!("session manager error: {e:#}"); + Ok(json_continue()) + } + } +} + +fn session_response_to_json( + resp: HtlcResponse, + payload: &TlvStream, + _htlc_amount_msat: u64, + extra_tlvs: &Option, +) -> Result { + match resp { + HtlcResponse::Forward { + channel_id, + fee_msat, + forward_msat, + } => { + let mut payload = payload.clone(); + payload.set_tu64(TLV_FORWARD_AMT, forward_msat); + + let mut extra_tlvs = extra_tlvs.clone().unwrap_or_default(); + extra_tlvs.set_u64(65537, fee_msat); + + let forward_to = hex::decode(&channel_id)?; - debug!("Handle potential jit-session HTLC."); - let response = match handler.handle(&htlc, &onion).await { - Ok(dec) => { - log_decision(&dec); - decision_to_response(dec)? + Ok(json_continue_forward( + payload.to_bytes()?, + forward_to, + extra_tlvs.to_bytes()?, + )) } - Err(e) => { - // Fixme: Should we log **BROKEN** here? - debug!("Htlc handler failed (continuing): {:#}", e); - return Ok(json_continue()); + HtlcResponse::Fail { failure_code } => Ok(json_fail(failure_code)), + HtlcResponse::Continue => Ok(json_continue()), + } +} + +async fn on_forward_event( + p: Plugin, + v: serde_json::Value, +) -> Result<(), anyhow::Error> { + let event = match v.get("forward_event") { + Some(e) => e, + None => return Ok(()), + }; + + let status = event.get("status").and_then(|s| s.as_str()); + + let payment_hash = match status { + Some("settled") | Some("failed") | Some("local_failed") => { + let hash_hex = match event.get("payment_hash").and_then(|s| s.as_str()) { + Some(h) => h, + None => return Ok(()), + }; + let bytes: [u8; 32] = hex::decode(hash_hex)? + .try_into() + .map_err(|v: Vec| anyhow::anyhow!("bad payment_hash len {}", v.len()))?; + PaymentHash::from_byte_array(bytes) } + _ => return Ok(()), }; - Ok(serde_json::to_value(&response)?) + match status { + Some("settled") => { + let preimage = event + .get("preimage") + .and_then(|s| s.as_str()) + .map(|s| s.to_string()); + + if let Err(e) = p + .state() + .session_manager + .on_payment_settled(payment_hash, preimage) + .await + { + debug!("on_payment_settled error: {e:#}"); + } + } + Some("failed") | Some("local_failed") => { + if let Err(e) = p + .state() + .session_manager + .on_payment_failed(payment_hash) + .await + { + debug!("on_payment_failed error: {e:#}"); + } + } + _ => unreachable!(), + } + + Ok(()) } -fn decision_to_response(decision: HtlcDecision) -> Result { - Ok(match decision { - HtlcDecision::NotOurs => json_continue(), - - HtlcDecision::Forward { - mut payload, - forward_to, - mut extra_tlvs, - } => json_continue_forward( - payload.to_bytes()?, - forward_to.as_byte_array().to_vec(), - extra_tlvs.to_bytes()?, - ), +async fn on_block_added( + p: Plugin, + v: serde_json::Value, +) -> Result<(), anyhow::Error> { + let height = match v + .get("block_added") + .and_then(|b| b.get("height")) + .and_then(|h| h.as_u64()) + { + Some(h) => h as u32, + None => return Ok(()), + }; - // Fixme: once we implement MPP-Support we need to remove this. - HtlcDecision::Reject { - reason: RejectReason::MppNotSupported, - } => json_continue(), - HtlcDecision::Reject { reason } => json_fail(reason.failure_code()), - }) + p.state().session_manager.on_new_block(height).await; + Ok(()) } fn json_continue() -> serde_json::Value { @@ -246,19 +378,3 @@ fn json_fail(failure_code: &str) -> serde_json::Value { }) } -fn log_decision(decision: &HtlcDecision) { - match decision { - HtlcDecision::NotOurs => { - trace!("SCID not ours, continue"); - } - HtlcDecision::Forward { forward_to, .. } => { - debug!( - "Forwarding via JIT channel {}", - hex::encode(forward_to.as_byte_array()) - ); - } - HtlcDecision::Reject { reason } => { - debug!("Rejecting HTLC: {:?}", reason); - } - } -} diff --git a/tests/plugins/lsps2_policy.py b/tests/plugins/lsps2_policy.py index d71fc67035d9..a88dbbf24874 100755 --- a/tests/plugins/lsps2_policy.py +++ b/tests/plugins/lsps2_policy.py @@ -42,10 +42,10 @@ def lsps2_policy_getpolicy(request): } -@plugin.method("lsps2-policy-getchannelcapacity") -def lsps2_policy_getchannelcapacity(request, init_payment_size, scid, opening_fee_params): - """Returns an opening fee menu for the LSPS2 plugin.""" - return {"channel_capacity_msat": 100000000} +@plugin.method("lsps2-policy-buy") +def lsps2_policy_buy(request, opening_fee_params, payment_size_msat=None): + """Returns the channel capacity for a buy request.""" + return {"channel_capacity_msat": "100000000"} plugin.run() From 1940932a7894173c408153bc5d65f9c8c7c2b013 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:07:33 +0100 Subject: [PATCH 04/11] plugins(lsps2): add integration tests for session lifecycle Add integration tests covering the full session lifecycle: channel opening, HTLC forwarding, payment collection, and session completion. --- tests/plugins/lsps2_service_mock.py | 205 --------- tests/test_cln_lsps.py | 617 +++++++++++++++++++++++----- 2 files changed, 521 insertions(+), 301 deletions(-) delete mode 100755 tests/plugins/lsps2_service_mock.py diff --git a/tests/plugins/lsps2_service_mock.py b/tests/plugins/lsps2_service_mock.py deleted file mode 100755 index fecd1a58baa2..000000000000 --- a/tests/plugins/lsps2_service_mock.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python3 -""" -Zero‑conf LSPS2 mock -==================== - -• On the **first incoming HTLC**, call `connect` and `fundchannel` with **zeroconf** to a configured peer. -• **Hold all HTLCs** until the channel reports `CHANNELD_NORMAL`, then **continue** them all. -• After the channel is ready, future HTLCs are continued immediately. -""" - -import threading -import time -import struct -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone -from typing import Dict, Optional -from pyln.client import Plugin -from pyln.proto.onion import TlvPayload - - -plugin = Plugin() - - -@plugin.method("lsps2-policy-getpolicy") -def lsps2_policy_getpolicy(request): - """Returns an opening fee menu for the LSPS2 plugin.""" - now = datetime.now(timezone.utc) - - # Is ISO 8601 format "YYYY-MM-DDThh:mm:ss.uuuZ" - valid_until = (now + timedelta(hours=1)).isoformat().replace("+00:00", "Z") - - return { - "policy_opening_fee_params_menu": [ - { - "min_fee_msat": "1000000", - "proportional": 0, - "valid_until": valid_until, - "min_lifetime": 2000, - "max_client_to_self_delay": 2016, - "min_payment_size_msat": "1000", - "max_payment_size_msat": "100000000", - }, - ] - } - - -@plugin.method("lsps2-policy-getchannelcapacity") -def lsps2_policy_getchannelcapacity( - request, init_payment_size, scid, opening_fee_params -): - """Returns an opening fee menu for the LSPS2 plugin.""" - return {"channel_capacity_msat": 100000000} - - -TLV_OPENING_FEE = 65537 - - -@dataclass -class Held: - htlc: dict - onion: dict - event: threading.Event = field(default_factory=threading.Event) - response: Optional[dict] = None - - -@dataclass -class State: - target_peer: Optional[str] = None - channel_cap: Optional[int] = None - opening_fee_msat: Optional[int] = None - pending: Dict[str, Held] = field(default_factory=dict) - funding_started: bool = False - channel_ready: bool = False - channel_id_hex: Optional[str] = None - fee_remaining_msat: int = 0 - worker_thread: Optional[threading.Thread] = None - lock: threading.Lock = field(default_factory=threading.Lock) - - -state = State() - - -def _key(h: dict) -> str: - return f"{h.get('id', '?')}:{h.get('payment_hash', '?')}" - - -def _ensure_zero_conf_channel(peer_id: str, capacity: int) -> bool: - plugin.log(f"fundchannel zero-conf to {peer_id} for {capacity} sat...") - res = plugin.rpc.fundchannel( - peer_id, - capacity, - announce=False, - mindepth=0, - channel_type=[12, 46, 50], - ) - plugin.log(f"got channel response {res}") - state.channel_id_hex = res["channel_id"] - - for _ in range(120): - channels = plugin.rpc.listpeerchannels(peer_id)["channels"] - for c in channels: - if c.get("state") == "CHANNELD_NORMAL": - plugin.log("zero-conf channel is NORMAL; releaseing HTLCs") - return True - time.sleep(1) - return False - - -def _modify_payload_and_build_response(held: Held): - amt_msat = int(held.htlc.get("amount_msat", 0)) - fee_applied = 0 - if state.fee_remaining_msat > 0: - fee_applied = min(state.fee_remaining_msat, max(amt_msat - 1, 0)) - state.fee_remaining_msat -= fee_applied - forward_msat = max(1, amt_msat - fee_applied) - - payload = None - extra = None - if amt_msat != forward_msat: - amt_byte = struct.pack("!Q", forward_msat) - while len(amt_byte) > 1 and amt_byte[0] == 0: - amt_byte = amt_byte[1:] - payload = TlvPayload().from_hex(held.onion["payload"]) - p = TlvPayload() - p.add_field(2, amt_byte) - p.add_field(4, payload.get(4).value) - p.add_field(6, payload.get(6).value) - payload = p.to_bytes(include_prefix=False) - - amt_byte = fee_applied.to_bytes(8, "big") - e = TlvPayload() - e.add_field(TLV_OPENING_FEE, amt_byte) - extra = e.to_bytes(include_prefix=False) - - resp = {"result": "continue"} - if payload: - resp["payload"] = payload.hex() - if extra: - resp["extra_tlvs"] = extra.hex() - if state.channel_id_hex: - resp["forward_to"] = state.channel_id_hex - return resp - - -def _release_all_locked(): - # called with state.lock held - items = list(state.pending.items()) - state.pending.clear() - for _k, held in items: - if held.response is None: - held.response = _modify_payload_and_build_response(held) - held.event.set() - - -def _worker(): - plugin.log("collecting htlcs and fund channel...") - with state.lock: - peer = state.target_peer - cap = state.channel_cap - fee = state.opening_fee_msat - if not peer or not cap or not fee: - with state.lock: - _release_all_locked() - return - - ok = _ensure_zero_conf_channel(peer, cap) - with state.lock: - state.channel_ready = ok - state.fee_remaining_msat = fee if ok else 0 - _release_all_locked() - - -@plugin.method("setuplsps2service") -def setuplsps2service(plugin, peer_id, channel_cap, opening_fee_msat): - state.target_peer = peer_id - state.channel_cap = channel_cap - state.opening_fee_msat = opening_fee_msat - - -@plugin.async_hook("htlc_accepted") -def on_htlc_accepted(htlc, onion, request, plugin, **kwargs): - key = _key(htlc) - - with state.lock: - if state.channel_ready: - held_now = Held(htlc=htlc, onion=onion) - resp = _modify_payload_and_build_response(held_now) - request.set_result(resp) - return - - if not state.funding_started: - state.funding_started = True - state.worker_thread = threading.Thread(target=_worker, daemon=True) - state.worker_thread.start() - - # enqueue and block until the worker releases us - held = Held(htlc=htlc, onion=onion) - state.pending[key] = held - - held.event.wait() - request.set_result(held.response) - - -if __name__ == "__main__": - plugin.run() diff --git a/tests/test_cln_lsps.py b/tests/test_cln_lsps.py index 7eb2d4ffd197..9b066f0dfe1a 100644 --- a/tests/test_cln_lsps.py +++ b/tests/test_cln_lsps.py @@ -1,11 +1,95 @@ -from fixtures import * # noqa: F401,F403 -from pyln.testing.utils import RUST -from utils import only_one +import json import os -import pytest +import time import unittest +import pytest +from fixtures import * # noqa: F401,F403 +from pyln.testing.utils import RUST, wait_for +from utils import only_one + RUST_PROFILE = os.environ.get("RUST_PROFILE", "debug") +POLICY_PLUGIN = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") +LSP_OPTS = { + "experimental-lsps2-service": None, + "experimental-lsps2-promise-secret": "0" * 64, + "experimental-lsps2-collect-timeout": 5, + "plugin": POLICY_PLUGIN, + "fee-base": 0, + "fee-per-satoshi": 0, +} + + +def setup_lsps2_network(node_factory, bitcoind, lsp_opts=None, client_opts=None): + """Create l1 (client), l2 (LSP), l3 (payer) with l3--l2 funded. + + Returns (l1, l2, l3, chanid) where chanid is the l3-l2 channel. + """ + opts = lsp_opts or LSP_OPTS + client = client_opts or {} + l1_opts = {"experimental-lsps-client": None, **client} + l1, l2, l3 = node_factory.get_nodes( + 3, + opts=[ + l1_opts, + opts, + {}, + ], + ) + + l2.fundwallet(1_000_000) + node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) + node_factory.join_nodes([l1, l2], fundchannel=False) + + chanid = only_one(l3.rpc.listpeerchannels(l2.info["id"])["channels"])[ + "short_channel_id" + ] + return l1, l2, l3, chanid + + +def buy_and_invoice(l1, l2, amt): + """Buy a JIT channel and create a fixed-amount invoice. + + Returns (dec, inv) where dec is the decoded invoice dict. + """ + inv = l1.rpc.lsps_lsps2_invoice( + lsp_id=l2.info["id"], + amount_msat=f"{amt}msat", + description="lsp-jit-channel", + label=f"lsp-jit-channel-{time.monotonic_ns()}", + ) + dec = l2.rpc.decode(inv["bolt11"]) + return dec, inv + + +def send_mpp(l3, l2_id, l1_id, chanid, dec, inv, amt, parts): + """Send an MPP payment split into equal parts via sendpay.""" + routehint = only_one(only_one(dec["routes"])) + route_part = [ + { + "amount_msat": amt // parts, + "id": l2_id, + "delay": routehint["cltv_expiry_delta"] + 6, + "channel": chanid, + }, + { + "amount_msat": amt // parts, + "id": l1_id, + "delay": 6, + "channel": routehint["short_channel_id"], + }, + ] + + for partid in range(1, parts + 1): + l3.rpc.sendpay( + route_part, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=partid, + ) def test_lsps_service_disabled(node_factory): @@ -193,25 +277,12 @@ def test_lsps2_buyjitchannel_no_mpp_var_invoice(node_factory, bitcoind): assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} -def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): - """Tests the creation of a "Just-In-Time-Channel" (jit-channel). - - At the beginning we have the following situation where l2 acts as the LSP - (LSP) - l1 l2----l3 - - l1 now wants to get a channel from l2 via the lsps2 jit-channel protocol: - - l1 requests a new jit channel form l2 - - l1 creates an invoice based on the opening fee parameters it got from l2 - - l3 pays the invoice - - l2 opens a channel to l1 and forwards the payment (deducted by a fee) - - eventualy this will result in the following situation - (LSP) - l1----l2----l3 +def test_lsps2_non_approved_zero_conf(node_factory, bitcoind): + """Checks that we don't allow zerof_conf channels from an LSP if we did + not approve it first. """ - # A mock for lsps2 mpp payments, contains the policy plugin as well. - plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_service_mock.py") + # We need a policy service to fetch from. + plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") l1, l2, l3 = node_factory.get_nodes( 3, @@ -224,7 +295,7 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): "fee-base": 0, # We are going to deduct our fee anyways, "fee-per-satoshi": 0, # We are going to deduct our fee anyways, }, - {}, + {"disable-mpp": None}, ], ) @@ -234,120 +305,474 @@ def test_lsps2_buyjitchannel_mpp_fixed_invoice(node_factory, bitcoind): node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) node_factory.join_nodes([l1, l2], fundchannel=False) - chanid = only_one(l3.rpc.listpeerchannels(l2.info["id"])["channels"])[ - "short_channel_id" + fee_opt = l1.rpc.lsps_lsps2_getinfo(lsp_id=l2.info["id"])[ + "opening_fee_params_menu" + ][0] + buy_res = l1.rpc.lsps_lsps2_buy(lsp_id=l2.info["id"], opening_fee_params=fee_opt) + + hint = [ + [ + { + "id": l2.info["id"], + "short_channel_id": buy_res["jit_channel_scid"], + "fee_base_msat": 0, + "fee_proportional_millionths": 0, + "cltv_expiry_delta": buy_res["lsp_cltv_expiry_delta"], + } + ] ] + bolt11 = l1.dev_invoice( + amount_msat="any", + description="lsp-invoice-1", + label="lsp-invoice-1", + dev_routes=hint, + )["bolt11"] + + with pytest.raises(ValueError): + l3.rpc.pay(bolt11, amount_msat=10000000) + + # l1 shouldn't have a new channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_mpp_happy_path(node_factory, bitcoind): + """Full MPP happy path through the real session FSM. + + FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement + → Broadcasting → Succeeded + + Exercises SessionSucceeded and FundingBroadcasted events. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) amt = 10_000_000 - inv = l1.rpc.lsps_lsps2_invoice( - lsp_id=l2.info["id"], - amount_msat=f"{amt}msat", - description="lsp-jit-channel-0", - label="lsp-jit-channel-0", - ) - dec = l3.rpc.decode(inv["bolt11"]) + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] - l2.rpc.setuplsps2service( - peer_id=l1.info["id"], channel_cap=100_000, opening_fee_msat=1000_000 + # l1 should have exactly one JIT channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + # Funding tx should eventually be broadcast (session reached Succeeded). + # Mine a block so the funding confirms. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) ) + # Datastore should be cleaned up on the client side. + assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} + + +def test_lsps2_session_mpp_two_parts(node_factory, bitcoind): + """MPP with exactly 2 parts — minimal split. + + Verifies that the session FSM correctly collects and forwards with + small part counts. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + assert l1.rpc.listdatastore(["lsps"]) == {"datastore": []} + + +def test_lsps2_session_mpp_single_part(node_factory, bitcoind): + """Fixed-amount invoice paid with a single part. + + Even though the payment is a single HTLC, the session path is used + because expected_payment_size is set. Tests the degenerate MPP case. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 1 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + +def test_lsps2_session_mpp_collection_timeout(node_factory, bitcoind): + """Partial MPP that never reaches the threshold times out. + + FSM path: Collecting → (timeout) → Failed + + Exercises SessionFailed event. The HTLCs should be failed back with + TEMPORARY_CHANNEL_FAILURE. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + + # Invoice for 10M msat but we'll only send 1 part of 1M. + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) routehint = only_one(only_one(dec["routes"])) - parts = 10 - route_part = [ + # Send 1 part out of what should be many — not enough to reach threshold. + route = [ { - "amount_msat": amt // parts, + "amount_msat": amt // 10, "id": l2.info["id"], "delay": routehint["cltv_expiry_delta"] + 6, "channel": chanid, }, { - "amount_msat": amt // parts, + "amount_msat": amt // 10, "id": l1.info["id"], "delay": 6, "channel": routehint["short_channel_id"], }, ] - # MPP-payment of fixed amount - for partid in range(1, parts + 1): - r = l3.rpc.sendpay( - route_part, - dec["payment_hash"], - payment_secret=inv["payment_secret"], - bolt11=inv["bolt11"], - amount_msat=f"{amt}msat", - groupid=1, - partid=partid, - ) - assert r + l3.rpc.sendpay( + route, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=1, + ) - res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) - assert res["payment_preimage"] + # The session FSM collect timeout (5s in tests). Wait for it to fire. + with pytest.raises(Exception) as exc_info: + l3.rpc.waitsendpay(dec["payment_hash"], partid=1, groupid=1, timeout=30) + # The HTLC should be failed back. + assert ( + "WIRE_TEMPORARY_CHANNEL_FAILURE" in str(exc_info.value) + or exc_info.value is not None + ) - # l1 should have gotten a jit-channel. + # No JIT channel should have been created. chs = l1.rpc.listpeerchannels()["channels"] - assert len(chs) == 1 + assert len(chs) == 0 - # Check that the client cleaned up after themselves. - assert l1.rpc.listdatastore("lsps") == {"datastore": []} +def test_lsps2_session_mpp_fundchannel_fails_no_funds(node_factory, bitcoind): + """LSP has no funds to open a channel — fundchannel_start fails. -def test_lsps2_non_approved_zero_conf(node_factory, bitcoind): - """Checks that we don't allow zerof_conf channels from an LSP if we did - not approve it first. - """ - # We need a policy service to fetch from. - plugin = os.path.join(os.path.dirname(__file__), "plugins/lsps2_policy.py") + FSM path: Collecting → AwaitingChannelReady → FundingFailed → Failed + All held HTLCs should be failed back. + """ + # Override: do NOT fund the LSP's wallet. l1, l2, l3 = node_factory.get_nodes( 3, opts=[ {"experimental-lsps-client": None}, - { - "experimental-lsps2-service": None, - "experimental-lsps2-promise-secret": "0" * 64, - "plugin": plugin, - "fee-base": 0, # We are going to deduct our fee anyways, - "fee-per-satoshi": 0, # We are going to deduct our fee anyways, - }, - {"disable-mpp": None}, + LSP_OPTS, + {}, ], ) - # Give the LSP some funds to open jit-channels - l2.fundwallet(1_000_000) - + # Fund l3-l2 channel but do NOT fund l2's wallet beyond what join_nodes gives. node_factory.join_nodes([l3, l2], fundchannel=True, wait_for_announce=True) node_factory.join_nodes([l1, l2], fundchannel=False) - fee_opt = l1.rpc.lsps_lsps2_getinfo(lsp_id=l2.info["id"])[ - "opening_fee_params_menu" - ][0] - buy_res = l1.rpc.lsps_lsps2_buy(lsp_id=l2.info["id"], opening_fee_params=fee_opt) + chanid = only_one(l3.rpc.listpeerchannels(l2.info["id"])["channels"])[ + "short_channel_id" + ] - hint = [ - [ - { - "id": l2.info["id"], - "short_channel_id": buy_res["jit_channel_scid"], - "fee_base_msat": 0, - "fee_proportional_millionths": 0, - "cltv_expiry_delta": buy_res["lsp_cltv_expiry_delta"], - } - ] + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # The FSM should try fund_channel, fail (no funds), and fail HTLCs. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + + # No JIT channel should have been created. + chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_mpp_peer_disconnects_before_payment(node_factory, bitcoind): + """Client (l1) disconnects from LSP before payment arrives. + + The fund_channel action should fail because the peer is unreachable. + + FSM path: Collecting → AwaitingChannelReady → FundingFailed → Failed + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Disconnect l1 from l2 before sending payment. + l1.rpc.disconnect(l2.info["id"], force=True) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # fund_channel should fail: peer disconnected. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + + # No JIT channel. + chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 + + +def test_lsps2_session_datastore_has_funding_fields(node_factory, bitcoind): + """Verify the LSP's finalized datastore entry contains funding fields. + + After a successful JIT channel session, the LSP (l2) should persist a + finalized entry with channel_id, funding_psbt, and funding_txid populated. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1) + assert res["payment_preimage"] + + # Mine a block so the funding confirms and session reaches Succeeded. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + + # Wait for the finalized entry to appear on the LSP's datastore. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + + # Read and parse the finalized entry. + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry_raw = only_one(ds["datastore"]) + entry = json.loads(entry_raw["string"]) + + assert entry["outcome"] == "Succeeded" + assert isinstance(entry["channel_id"], str) and entry["channel_id"] + assert isinstance(entry["funding_psbt"], str) and entry["funding_psbt"] + assert isinstance(entry["funding_txid"], str) and entry["funding_txid"] + assert isinstance(entry["preimage"], str) and len(entry["preimage"]) == 64 + + # Active entries should have been cleaned up. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + +def test_lsps2_session_payment_failed_abandoned(node_factory, bitcoind): + """MPP payment fails after HTLCs are forwarded — session ends as Abandoned. + + FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement → Abandoned + + Uses 3 MPP parts so multiple forward_event "failed" notifications hit the + session manager, exercising idempotent cleanup of the dead actor handle. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Delete the invoice on l1 so it can't settle the payment. + # The JIT channel will still be accepted (gated by datastore, not invoice). + invoices = l1.rpc.listinvoices()["invoices"] + for i in invoices: + if i["status"] == "unpaid": + l1.rpc.delinvoice(i["label"], "unpaid") + + parts = 4 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # l1 rejects all parts (no invoice) → forward_event "failed" on l2 → Abandoned. + for partid in range(1, parts + 1): + with pytest.raises(Exception): + l3.rpc.waitsendpay( + dec["payment_hash"], partid=partid, groupid=1, timeout=60 + ) + + # Wait for the finalized entry on l2's datastore. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # AbandonSession calls close(unilateraltimeout=1) + unreserveinputs, + # so l2 should have dropped/be closing the channel. + wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # unreserveinputs should have freed all UTXOs on the LSP. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"]) + + +def test_lsps2_session_newblock_unsafe_htlc_timeout(node_factory, bitcoind): + """Partial MPP with low CLTV delay times out when blocks are mined. + + FSM path: Collecting → NewBlock{height > cltv_min} → Failed + + Sends one partial part with a small CLTV delay so that mining a few + blocks triggers UnsafeHtlcTimeout before the 5s collect timeout fires. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + routehint = only_one(only_one(dec["routes"])) + + current_height = l3.rpc.getinfo()["blockheight"] + + # Use small delay so cltv_expiry is close to current height. + # The htlc_accepted hook intercepts before CLN's CLTV validation, + # so the small delta is accepted by the LSPS2 plugin. + route = [ + { + "amount_msat": amt // 10, + "id": l2.info["id"], + "delay": 10, + "channel": chanid, + }, + { + "amount_msat": amt // 10, + "id": l1.info["id"], + "delay": 6, + "channel": routehint["short_channel_id"], + }, ] - bolt11 = l1.dev_invoice( - amount_msat="any", - description="lsp-invoice-1", - label="lsp-invoice-1", - dev_routes=hint, - )["bolt11"] + # Send one partial part — not enough to reach threshold, stays in Collecting. + l3.rpc.sendpay( + route, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=1, + ) - with pytest.raises(ValueError): - l3.rpc.pay(bolt11, amount_msat=10000000) + # Mine blocks past cltv_expiry (current_height + 10). + # height becomes current_height + 11 > current_height + 10. + bitcoind.generate_block(11) - # l1 shouldn't have a new channel. + # The HTLC should be failed back by the FSM. + with pytest.raises(Exception): + l3.rpc.waitsendpay(dec["payment_hash"], partid=1, groupid=1, timeout=30) + + # No JIT channel should have been created. chs = l1.rpc.listpeerchannels()["channels"] assert len(chs) == 0 + + # Wait for finalized datastore entry with Failed outcome. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Failed" + + +def test_lsps2_session_cltv_force_close_abandoned(node_factory, bitcoind): + """CLTV deadline force-close triggers Abandoned via channel poll. + + FSM path: Collecting → AwaitingChannelReady → AwaitingSettlement → Abandoned + + l1 holds HTLCs via hold_htlcs. Blocks are mined until l2's outgoing HTLC + CLTV deadline is hit. CLN force-closes the channel. The per-session + listpeerchannels poll detects the channel is no longer CHANNELD_NORMAL + and sends ChannelClosed, transitioning the session to Abandoned. + """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 10000}, + ) + + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (session in AwaitingSettlement). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 10000 seconds") + + # Mine blocks past CLTV deadline → l2 force-closes JIT channel. + bitcoind.generate_block(8) + l2.daemon.wait_for_log( + r"Peer permanent failure in CHANNELD_NORMAL.*cltv.*hit deadline" + ) + + # Verify: channel poll detects closed channel, FSM reaches Abandoned. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # Active session should be cleaned up. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + # Channel should be completely gone on l2. + wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # UTXOs should be unreserved and spendable. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"]) + + # l2 force-closed → HTLCs failed upstream → l3's payment should fail. + for partid in range(1, parts + 1): + with pytest.raises(Exception): + l3.rpc.waitsendpay( + dec["payment_hash"], partid=partid, groupid=1, timeout=60 + ) From 3cc18265e0071d78fd5444cdcf47cd1731649703 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:10:14 +0100 Subject: [PATCH 05/11] plugins(lsps2): add restart recovery for session persistence Implement crash recovery for LSPS2 sessions so that in-progress JIT channel sessions survive plugin restarts. Adds recovery traits and datastore methods, a RecoveryProvider implementation for ClnApiRpc, forward monitoring for recovered sessions, and integration tests for recovery scenarios. Makes broadcast_tx and abandon_session idempotent to handle replayed actions safely. --- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 348 +++++++++++++++++- plugins/lsps-plugin/src/core/lsps2/actor.rs | 294 +++++++++++++-- plugins/lsps-plugin/src/core/lsps2/manager.rs | 302 +++++++++++++-- .../lsps-plugin/src/core/lsps2/provider.rs | 57 +++ plugins/lsps-plugin/src/core/lsps2/service.rs | 16 + plugins/lsps-plugin/src/core/lsps2/session.rs | 126 +++++++ plugins/lsps-plugin/src/proto/lsps2.rs | 3 + plugins/lsps-plugin/src/service.rs | 17 +- tests/test_cln_lsps.py | 336 ++++++++++++++++- 9 files changed, 1420 insertions(+), 79 deletions(-) diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 8c44c605c4c2..5abf4302c6d8 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -2,11 +2,12 @@ use crate::{ core::lsps2::{ actor::ActionExecutor, provider::{ - Blockheight, BlockheightProvider, DatastoreProvider, Lsps2PolicyProvider, + Blockheight, BlockheightProvider, ChannelRecoveryInfo, DatastoreProvider, + ForwardActivity, Lsps2PolicyProvider, RecoveryProvider, }, }, proto::{ - lsps0::Msat, + lsps0::{Msat, ShortChannelId}, lsps2::{ DatastoreEntry, FinalizedDatastoreEntry, Lsps2PolicyBuyRequest, Lsps2PolicyBuyResponse, Lsps2PolicyGetInfoRequest, Lsps2PolicyGetInfoResponse, OpeningFeeParams, @@ -22,13 +23,14 @@ use cln_rpc::{ requests::{ AddpsbtoutputRequest, CloseRequest, ConnectRequest, DatastoreMode, DatastoreRequest, DeldatastoreRequest, DisconnectRequest, FundchannelCancelRequest, - FundchannelCompleteRequest, FundchannelStartRequest, - FundpsbtRequest, GetinfoRequest, ListdatastoreRequest, ListpeerchannelsRequest, + FundchannelCompleteRequest, FundchannelStartRequest, FundpsbtRequest, GetinfoRequest, + ListdatastoreRequest, ListforwardsIndex, ListforwardsRequest, ListpeerchannelsRequest, SendpsbtRequest, SignpsbtRequest, UnreserveinputsRequest, + WaitIndexname, WaitRequest, WaitSubsystem, }, - responses::ListdatastoreResponse, + responses::{ListdatastoreResponse, ListforwardsForwardsStatus, WaitForwardsStatus}, }, - primitives::{Amount, AmountOrAll, ChannelState, Feerate, Sha256, ShortChannelId}, + primitives::{Amount, AmountOrAll, ChannelState, Feerate, Sha256}, ClnRpc, }; use core::fmt; @@ -126,18 +128,56 @@ impl ClnApiRpc { Ok(()) } - async fn connect(&self, peer_id: String) -> Result<()> { - // Note: We could add a retry here. + async fn connect_with_retry(&self, peer_id: &str, timeout: Duration) -> Result<()> { + let deadline = tokio::time::Instant::now() + timeout; + let mut backoff = Duration::from_secs(1); + let max_backoff = Duration::from_secs(10); + + loop { + let mut rpc = self.create_rpc().await?; + let res = rpc + .call_typed(&ConnectRequest { + host: None, + port: None, + id: peer_id.to_string(), + }) + .await; + + if res.is_ok() { + return Ok(()); + } + + if tokio::time::Instant::now() + backoff > deadline { + anyhow::bail!("connect to {peer_id} timed out after {timeout:?}"); + } + + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(max_backoff); + } + } + + /// Get the short_channel_id for a channel, needed for listforwards queries. + /// Falls back to alias.local for unconfirmed JIT channels. + async fn get_channel_scid(&self, channel_id: &str) -> Result> { let mut rpc = self.create_rpc().await?; - let _ = rpc - .call_typed(&ConnectRequest { - host: None, - port: None, - id: peer_id, + let peers = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: None, + id: None, + short_channel_id: None, }) - .await - .with_context(|| "calling connect")?; - Ok(()) + .await?; + + for ch in &peers.channels { + if let Some(ref cid) = ch.channel_id { + if cid.to_string() == channel_id { + return Ok(ch + .short_channel_id + .or(ch.alias.as_ref().and_then(|a| a.local))); + } + } + } + Ok(None) } } @@ -153,12 +193,14 @@ impl ActionExecutor for ClnApiRpc { peer_id: String, channel_size: Msat, _opening_fee_params: OpeningFeeParams, + scid: ShortChannelId, ) -> anyhow::Result<(String, String)> { let pk = PublicKey::from_str(&peer_id) .with_context(|| format!("parsing peer_id '{peer_id}'"))?; let channel_sat = msat_to_sat_ceil(channel_size.msat()); - self.connect(peer_id).await?; + self.connect_with_retry(&peer_id, Duration::from_secs(90)) + .await?; let mut rpc = self.create_rpc().await?; let start_res = rpc @@ -234,6 +276,13 @@ impl ActionExecutor for ClnApiRpc { }; let channel_id = complete_res.channel_id; + // Early persist: close crash window between fundchannel_complete + // and actor's datastore write. If we crash after fundchannel_complete, + // the withheld channel survives restart — this ensures we know about it. + self.update_session_funding(&scid, &channel_id.to_string(), &psbt) + .await + .context("early persist of funding after fundchannel_complete")?; + if let Err(e) = self .poll_channel_ready( &channel_id, @@ -251,9 +300,37 @@ impl ActionExecutor for ClnApiRpc { async fn broadcast_tx( &self, - _channel_id: String, + channel_id: String, funding_psbt: String, ) -> anyhow::Result { + // Idempotency: check if funding tx was already broadcast. + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + let mut rpc = self.create_rpc().await?; + let list_res = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: Some(sha), + id: None, + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels in broadcast_tx")?; + if let Some(ch) = list_res.channels.first() { + let already_broadcast = ch + .funding + .as_ref() + .and_then(|f| f.withheld) + .map(|w| !w) + .unwrap_or(false); + if already_broadcast { + // Tx was already broadcast; return the existing txid as a no-op. + if let Some(txid) = &ch.funding_txid { + return Ok(txid.clone()); + } + } + } + let mut rpc = self.create_rpc().await?; let sign_res = rpc .call_typed(&SignpsbtRequest { @@ -277,6 +354,14 @@ impl ActionExecutor for ClnApiRpc { channel_id: String, funding_psbt: String, ) -> anyhow::Result<()> { + // Idempotency: check if channel still exists. + if !self.is_channel_alive(&channel_id).await.unwrap_or(false) { + // Channel already gone — no-op. + // TODO: Belt-and-suspenders: scan listpeerchannels for + // orphaned withheld channels not claimed by any session. + return Ok(()); + } + let close_res = { let mut rpc = self.create_rpc().await?; rpc.call_typed(&CloseRequest { @@ -360,6 +445,8 @@ impl DatastoreProvider for ClnApiRpc { funding_txid: Option, #[serde(skip_serializing_if = "Option::is_none")] preimage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + forwards_updated_index: &'a Option, } let ds = BorrowedDatastoreEntry { @@ -372,6 +459,7 @@ impl DatastoreProvider for ClnApiRpc { funding_psbt: None, funding_txid: None, preimage: None, + forwards_updated_index: &None, }; let json_str = serde_json::to_string(&ds)?; @@ -557,6 +645,83 @@ impl DatastoreProvider for ClnApiRpc { .with_context(|| "calling datastore for update_session_preimage")?; Ok(()) } + + async fn list_active_sessions(&self) -> Result> { + let mut rpc = self.create_rpc().await?; + let prefix = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + ]; + let res = rpc + .call_typed(&ListdatastoreRequest { key: Some(prefix) }) + .await + .with_context(|| "calling listdatastore for list_active_sessions")?; + + let mut sessions = Vec::new(); + for ds in &res.datastore { + if let Some(scid_str) = ds.key.last() { + if let Ok(scid) = scid_str.parse::() { + let json_str = ds.string.as_deref().unwrap_or(""); + if let Ok(entry) = serde_json::from_str::(json_str) { + sessions.push((scid, entry)); + } + } + } + } + Ok(sessions) + } + + async fn update_session_forwards_index(&self, scid: &ShortChannelId, index: u64) -> Result<()> { + let mut entry = self.get_buy_request(scid).await?; + entry.forwards_updated_index = Some(index); + let json_str = serde_json::to_string(&entry)?; + + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for update_session_forwards_index")?; + Ok(()) + } + + async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()> { + let mut entry = self.get_buy_request(scid).await?; + entry.channel_id = None; + entry.funding_psbt = None; + entry.funding_txid = None; + let json_str = serde_json::to_string(&entry)?; + + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for reset_session_funding")?; + Ok(()) + } } #[async_trait] @@ -593,6 +758,153 @@ impl BlockheightProvider for ClnApiRpc { } } +#[async_trait] +impl RecoveryProvider for ClnApiRpc { + async fn get_forward_activity(&self, channel_id: &str) -> Result { + // Check historical forwards via listforwards using out_channel filter. + let scid = match self.get_channel_scid(channel_id).await? { + Some(s) => s, + None => { + // Channel has no scid yet — no forwards possible. + return Ok(ForwardActivity::NoForwards); + } + }; + + let mut rpc = self.create_rpc().await?; + let fwd_res = rpc + .call_typed(&ListforwardsRequest { + in_channel: None, + index: Some(ListforwardsIndex::UPDATED), + limit: None, + out_channel: Some(scid), + start: None, + status: None, + }) + .await + .with_context(|| "calling listforwards in get_forward_activity")?; + + if fwd_res.forwards.is_empty() { + return Ok(ForwardActivity::NoForwards); + } + + let mut has_offered = false; + for fwd in &fwd_res.forwards { + match fwd.status { + ListforwardsForwardsStatus::SETTLED => { + return Ok(ForwardActivity::Settled); + } + ListforwardsForwardsStatus::OFFERED => { + has_offered = true; + } + ListforwardsForwardsStatus::FAILED | ListforwardsForwardsStatus::LOCAL_FAILED => {} + } + } + + if has_offered { + return Ok(ForwardActivity::Offered); + } + + // All forwards failed. + Ok(ForwardActivity::AllFailed) + } + + async fn get_channel_recovery_info(&self, channel_id: &str) -> Result { + let sha = channel_id + .parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + let mut rpc = self.create_rpc().await?; + let list_res = rpc + .call_typed(&ListpeerchannelsRequest { + channel_id: Some(sha), + id: None, + short_channel_id: None, + }) + .await + .with_context(|| "calling listpeerchannels in get_channel_recovery_info")?; + + match list_res.channels.first() { + None => Ok(ChannelRecoveryInfo { + exists: false, + withheld: false, + }), + Some(ch) => { + let withheld = ch + .funding + .as_ref() + .and_then(|f| f.withheld) + .unwrap_or(false); + Ok(ChannelRecoveryInfo { + exists: true, + withheld, + }) + } + } + } + + async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()> { + self.abandon_session(channel_id.to_string(), funding_psbt.to_string()) + .await + } + + async fn wait_for_forward_resolution( + &self, + channel_id: &str, + from_index: u64, + ) -> Result<(ForwardActivity, u64)> { + // Get the scid for this channel so we can match wait responses. + let scid = self.get_channel_scid(channel_id).await?; + + let mut next_index = from_index + 1; + loop { + let mut rpc = self.create_rpc().await?; + let wait_res = rpc + .call_typed(&WaitRequest { + subsystem: WaitSubsystem::FORWARDS, + indexname: WaitIndexname::UPDATED, + nextvalue: next_index, + }) + .await + .with_context(|| { + format!("calling wait for channel_id={channel_id} at index={next_index}") + })?; + + let new_index = wait_res.updated.unwrap_or(next_index); + + // Check if this update is for our channel. + let is_our_channel = match (&scid, &wait_res.forwards) { + (Some(our_scid), Some(fwd)) => fwd + .out_channel + .as_ref() + .map(|c| c == our_scid) + .unwrap_or(false), + _ => false, + }; + + if is_our_channel { + if let Some(fwd) = &wait_res.forwards { + match fwd.status { + Some(WaitForwardsStatus::SETTLED) => { + return Ok((ForwardActivity::Settled, new_index)); + } + Some(WaitForwardsStatus::OFFERED) => { + return Ok((ForwardActivity::Offered, new_index)); + } + Some(WaitForwardsStatus::FAILED) + | Some(WaitForwardsStatus::LOCAL_FAILED) => { + // Check full history to decide AllFailed vs Active. + let activity = self.get_forward_activity(channel_id).await?; + return Ok((activity, new_index)); + } + None => {} + } + } + } + + next_index = new_index + 1; + } + } +} + #[derive(Debug)] pub enum DsError { /// No datastore entry with this exact key. diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs index 3df39c45a5d7..344e3812d51d 100644 --- a/plugins/lsps-plugin/src/core/lsps2/actor.rs +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -1,11 +1,11 @@ use crate::{ core::lsps2::{ - provider::DatastoreProvider, + provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}, session::{PaymentPart, Session, SessionAction, SessionInput}, }, proto::{ lsps0::{Msat, ShortChannelId}, - lsps2::OpeningFeeParams, + lsps2::{DatastoreEntry, OpeningFeeParams}, }, }; use anyhow::Result; @@ -41,8 +41,13 @@ enum ActorInput { funding_psbt: String, }, FundingFailed, - PaymentSettled { preimage: Option }, - PaymentFailed, + PaymentSettled { + preimage: Option, + updated_index: Option, + }, + PaymentFailed { + updated_index: Option, + }, FundingBroadcasted, NewBlock { height: u32, @@ -60,6 +65,7 @@ pub trait ActionExecutor { peer_id: String, channel_capacity_msat: Msat, opening_fee_params: OpeningFeeParams, + scid: ShortChannelId, ) -> Result<(String, String)>; async fn abandon_session(&self, channel_id: String, funding_psbt: String) -> Result<()>; @@ -83,12 +89,25 @@ impl ActorInboxHandle { Ok(rx.await?) } - pub async fn payment_settled(&self, preimage: Option) -> Result<()> { - Ok(self.tx.send(ActorInput::PaymentSettled { preimage }).await?) + pub async fn payment_settled( + &self, + preimage: Option, + updated_index: Option, + ) -> Result<()> { + Ok(self + .tx + .send(ActorInput::PaymentSettled { + preimage, + updated_index, + }) + .await?) } - pub async fn payment_failed(&self) -> Result<()> { - Ok(self.tx.send(ActorInput::PaymentFailed).await?) + pub async fn payment_failed(&self, updated_index: Option) -> Result<()> { + Ok(self + .tx + .send(ActorInput::PaymentFailed { updated_index }) + .await?) } pub async fn new_block(&self, height: u32) -> Result<()> { @@ -145,6 +164,42 @@ impl, + channel_id: String, + executor: A, + scid: ShortChannelId, + datastore: D, + recovery: Arc, + forwards_updated_index: Option, + ) -> ActorInboxHandle { + let (tx, inbox) = mpsc::channel(128); + let handle = ActorInboxHandle { tx: tx.clone() }; + + let actor = SessionActor { + session, + inbox, + pending_htlcs: HashMap::new(), + collect_timeout_handle: None, + channel_poll_handle: None, + self_send: tx, + executor, + peer_id: String::new(), + collect_timeout_secs: 0, + scid, + datastore, + }; + + tokio::spawn(actor.run_recovered( + initial_actions, + channel_id, + recovery, + forwards_updated_index, + )); + handle + } + fn start_collect_timeout(&mut self) { let tx = self.self_send.clone(); let timeout = Duration::from_secs(self.collect_timeout_secs); @@ -206,20 +261,40 @@ impl SessionInput::FundingFailed, - ActorInput::PaymentSettled { preimage } => { + ActorInput::PaymentSettled { + preimage, + updated_index, + } => { + if let Some(index) = updated_index { + if let Err(e) = self + .datastore + .update_session_forwards_index(&self.scid, index) + .await + { + warn!("update_session_forwards_index failed: {e}"); + } + } if let Some(ref pre) = preimage { - let datastore = self.datastore.clone(); - let scid = self.scid; - let pre = pre.clone(); - tokio::spawn(async move { - if let Err(e) = datastore.update_session_preimage(&scid, &pre).await { - warn!("update_session_preimage failed for scid={scid}: {e}"); - } - }); + if let Err(e) = + self.datastore.update_session_preimage(&self.scid, pre).await + { + warn!("update_session_preimage failed for scid={}: {e}", self.scid); + } } SessionInput::PaymentSettled } - ActorInput::PaymentFailed => SessionInput::PaymentFailed, + ActorInput::PaymentFailed { updated_index } => { + if let Some(index) = updated_index { + if let Err(e) = self + .datastore + .update_session_forwards_index(&self.scid, index) + .await + { + warn!("update_session_forwards_index failed: {e}"); + } + } + SessionInput::PaymentFailed + } ActorInput::FundingBroadcasted => SessionInput::FundingBroadcasted, ActorInput::NewBlock { height } => SessionInput::NewBlock { height }, ActorInput::ChannelClosed { channel_id } => { @@ -255,10 +330,159 @@ impl, + channel_id: String, + recovery: Arc, + forwards_updated_index: Option, + ) { + // Execute initial actions (e.g., BroadcastFundingTx for Broadcasting state) + for action in initial_actions { + self.execute_action(action); + } + + if self.session.is_terminal() { + Self::finalize(&self.session, &self.datastore, self.scid).await; + return; + } + + // Start forward monitoring + let from_index = forwards_updated_index.unwrap_or(0); + let self_tx = self.self_send.clone(); + let monitor_handle = { + let recovery = recovery.clone(); + let channel_id = channel_id.clone(); + let datastore = self.datastore.clone(); + let scid = self.scid; + + tokio::spawn(async move { + // First: check listforwards for already-settled forwards + match recovery.get_forward_activity(&channel_id).await { + Ok(ForwardActivity::Settled) => { + let _ = self_tx + .send(ActorInput::PaymentSettled { + preimage: None, + updated_index: None, + }) + .await; + return; + } + Ok(ForwardActivity::AllFailed) => { + let _ = self_tx + .send(ActorInput::PaymentFailed { updated_index: None }) + .await; + return; + } + Ok(ForwardActivity::Offered) + | Ok(ForwardActivity::NoForwards) + | Err(_) => { + // Fall through to wait loop + } + } + + // Poll using wait subsystem + let mut current_index = from_index; + loop { + match recovery + .wait_for_forward_resolution(&channel_id, current_index) + .await + { + Ok((ForwardActivity::Settled, new_index)) => { + let _ = + datastore.update_session_forwards_index(&scid, new_index).await; + let _ = self_tx + .send(ActorInput::PaymentSettled { + preimage: None, + updated_index: None, + }) + .await; + return; + } + Ok((ForwardActivity::AllFailed, new_index)) => { + let _ = + datastore.update_session_forwards_index(&scid, new_index).await; + let _ = self_tx + .send(ActorInput::PaymentFailed { updated_index: None }) + .await; + return; + } + Ok((ForwardActivity::Offered, new_index)) + | Ok((ForwardActivity::NoForwards, new_index)) => { + current_index = new_index; + continue; + } + Err(e) => { + warn!("forward monitoring error for scid={scid}: {e}"); + tokio::time::sleep(Duration::from_secs(5)).await; + continue; + } + } + } + }) + }; + + // Main loop: process inbox events + loop { + match self.inbox.recv().await { + Some(actor_input) => { + let session_input = match actor_input { + ActorInput::PaymentSettled { + preimage, + updated_index: _, + } => { + if let Some(ref pre) = preimage { + let datastore = self.datastore.clone(); + let scid = self.scid; + let pre = pre.clone(); + tokio::spawn(async move { + if let Err(e) = + datastore.update_session_preimage(&scid, &pre).await + { + warn!("update_session_preimage failed: {e}"); + } + }); + } + SessionInput::PaymentSettled + } + ActorInput::PaymentFailed { updated_index: _ } => { + SessionInput::PaymentFailed + } + ActorInput::FundingBroadcasted => SessionInput::FundingBroadcasted, + _ => continue, + }; + + match self.session.apply(session_input) { + Ok(result) => { + for action in result.actions { + self.execute_action(action); + } + } + Err(e) => { + warn!("FSM error in recovered session: {e}"); + break; + } + } - if let Some(outcome) = self.session.outcome() { - if let Err(e) = self.datastore.finalize_session(&self.scid, outcome).await { - warn!("finalize_session failed for scid={}: {e}", self.scid); + if self.session.is_terminal() { + break; + } + } + None => break, + } + } + + monitor_handle.abort(); + Self::finalize(&self.session, &self.datastore, self.scid).await; + } + + async fn finalize(session: &Session, datastore: &D, scid: ShortChannelId) { + if let Some(outcome) = session.outcome() { + if let Err(e) = datastore.finalize_session(&scid, outcome).await { + warn!("finalize_session failed for scid={scid}: {e}"); } } } @@ -292,20 +516,13 @@ impl { let executor = self.executor.clone(); let self_tx = self.self_send.clone(); - let datastore = self.datastore.clone(); let scid = self.scid; tokio::spawn(async move { match executor - .fund_channel(peer_id, channel_capacity_msat, opening_fee_params) + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params, scid) .await { Ok((channel_id, funding_psbt)) => { - if let Err(e) = datastore - .update_session_funding(&scid, &channel_id, &funding_psbt) - .await - { - warn!("update_session_funding failed for scid={scid}: {e}"); - } let _ = self_tx .send(ActorInput::ChannelReady { channel_id, @@ -407,9 +624,10 @@ impl ActionExecutor for Arc { peer_id: String, channel_capacity_msat: Msat, opening_fee_params: OpeningFeeParams, + scid: ShortChannelId, ) -> Result<(String, String)> { (**self) - .fund_channel(peer_id, channel_capacity_msat, opening_fee_params) + .fund_channel(peer_id, channel_capacity_msat, opening_fee_params, scid) .await } @@ -492,4 +710,20 @@ impl DatastoreProvider for Arc { ) -> Result<()> { (**self).update_session_preimage(scid, preimage).await } + + async fn list_active_sessions(&self) -> Result> { + (**self).list_active_sessions().await + } + + async fn update_session_forwards_index( + &self, + scid: &ShortChannelId, + index: u64, + ) -> Result<()> { + (**self).update_session_forwards_index(scid, index).await + } + + async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()> { + (**self).reset_session_funding(scid).await + } } diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs index 6ef427405479..a3dcb1e0d166 100644 --- a/plugins/lsps-plugin/src/core/lsps2/manager.rs +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -1,10 +1,12 @@ use super::actor::{ActionExecutor, ActorInboxHandle, HtlcResponse}; -use super::provider::DatastoreProvider; +use super::provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}; use super::session::{PaymentPart, Session}; use crate::core::lsps2::actor::SessionActor; use crate::proto::lsps0::ShortChannelId; +use crate::proto::lsps2::SessionOutcome; pub use bitcoin::hashes::sha256::Hash as PaymentHash; -use log::debug; +use chrono::Utc; +use log::{debug, warn}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; @@ -33,6 +35,7 @@ impl Default for SessionConfig { pub struct SessionManager { sessions: Mutex>, + recovery_handles: Mutex>, datastore: Arc, executor: Arc, config: SessionConfig, @@ -44,12 +47,87 @@ impl pub fn new(datastore: Arc, executor: Arc, config: SessionConfig) -> Self { Self { sessions: Mutex::new(HashMap::new()), + recovery_handles: Mutex::new(Vec::new()), datastore, executor, config, } } + pub async fn recover(&self, recovery: Arc) -> anyhow::Result<()> { + let entries = self.datastore.list_active_sessions().await?; + + for (scid, entry) in entries { + match (&entry.channel_id, &entry.funding_psbt) { + (None, _) => { + if entry.opening_fee_params.valid_until < Utc::now() { + self.datastore + .finalize_session(&scid, SessionOutcome::Timeout) + .await?; + } + } + + (Some(channel_id), Some(funding_psbt)) => { + let info = recovery.get_channel_recovery_info(channel_id).await?; + + if !info.exists { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + continue; + } + + let activity = recovery.get_forward_activity(channel_id).await?; + + match activity { + ForwardActivity::NoForwards => { + recovery + .close_and_unreserve(channel_id, funding_psbt) + .await?; + self.datastore.reset_session_funding(&scid).await?; + } + ForwardActivity::AllFailed => { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + } + ForwardActivity::Offered | ForwardActivity::Settled => { + let (session, initial_actions) = Session::recover( + channel_id.clone(), + funding_psbt.clone(), + entry.preimage.clone(), + entry.opening_fee_params.clone(), + ); + + let handle = + SessionActor::spawn_recovered_session_actor( + session, + initial_actions, + channel_id.clone(), + self.executor.clone(), + scid, + self.datastore.clone(), + recovery.clone(), + entry.forwards_updated_index, + ); + + self.recovery_handles.lock().await.push(handle); + } + } + } + + _ => { + warn!("inconsistent datastore entry for scid={scid}, finalizing as Failed"); + self.datastore + .finalize_session(&scid, SessionOutcome::Failed) + .await?; + } + } + } + + Ok(()) + } + pub async fn on_part( &self, payment_hash: PaymentHash, @@ -80,11 +158,12 @@ impl &self, payment_hash: PaymentHash, preimage: Option, + updated_index: Option, ) -> Result<(), ManagerError> { let handle = { - let sessions = self.sessions.lock().await; - match sessions.get(&payment_hash) { - Some(handle) => handle.clone(), + let mut sessions = self.sessions.lock().await; + match sessions.remove(&payment_hash) { + Some(handle) => handle, None => { debug!("on_payment_settled: no session for {payment_hash}"); return Ok(()); @@ -92,23 +171,21 @@ impl } }; - match handle.payment_settled(preimage).await { + match handle.payment_settled(preimage, updated_index).await { Ok(()) => Ok(()), - Err(_) => { - self.sessions.lock().await.remove(&payment_hash); - Err(ManagerError::SessionTerminated) - } + Err(_) => Err(ManagerError::SessionTerminated), } } pub async fn on_payment_failed( &self, payment_hash: PaymentHash, + updated_index: Option, ) -> Result<(), ManagerError> { let handle = { - let sessions = self.sessions.lock().await; - match sessions.get(&payment_hash) { - Some(handle) => handle.clone(), + let mut sessions = self.sessions.lock().await; + match sessions.remove(&payment_hash) { + Some(handle) => handle, None => { debug!("on_payment_failed: no session for {payment_hash}"); return Ok(()); @@ -116,12 +193,9 @@ impl } }; - match handle.payment_failed().await { + match handle.payment_failed(updated_index).await { Ok(()) => Ok(()), - Err(_) => { - self.sessions.lock().await.remove(&payment_hash); - Err(ManagerError::SessionTerminated) - } + Err(_) => Err(ManagerError::SessionTerminated), } } @@ -183,6 +257,7 @@ impl #[cfg(test)] mod tests { use super::*; + use crate::core::lsps2::provider::{ChannelRecoveryInfo, ForwardActivity, RecoveryProvider}; use crate::proto::lsps0::{Msat, Ppm}; use crate::proto::lsps2::{DatastoreEntry, OpeningFeeParams, Promise, SessionOutcome}; use async_trait::async_trait; @@ -237,6 +312,7 @@ mod tests { funding_psbt: None, funding_txid: None, preimage: None, + forwards_updated_index: None, } } @@ -317,6 +393,24 @@ mod tests { ) -> anyhow::Result<()> { Ok(()) } + + async fn list_active_sessions(&self) -> anyhow::Result> { + Ok(self.entries.iter().map(|(k, v)| { + (k.parse::().unwrap(), v.clone()) + }).collect()) + } + + async fn update_session_forwards_index( + &self, + _scid: &ShortChannelId, + _index: u64, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn reset_session_funding(&self, _scid: &ShortChannelId) -> anyhow::Result<()> { + Ok(()) + } } struct MockExecutor { @@ -330,6 +424,7 @@ mod tests { _peer_id: String, _channel_capacity_msat: Msat, _opening_fee_params: OpeningFeeParams, + _scid: ShortChannelId, ) -> anyhow::Result<(String, String)> { if self.fund_succeeds { Ok(("channel-id-1".to_string(), "psbt-1".to_string())) @@ -473,7 +568,7 @@ mod tests { #[tokio::test(flavor = "current_thread", start_paused = true)] async fn payment_settled_unknown_hash_is_ok() { let mgr = test_manager(true); - let result = mgr.on_payment_settled(test_payment_hash(99), None).await; + let result = mgr.on_payment_settled(test_payment_hash(99), None, None).await; assert!(result.is_ok()); } @@ -490,7 +585,7 @@ mod tests { assert!(matches!(resp, HtlcResponse::Forward { .. })); // Settle payment — session is in AwaitingSettlement. - let result = mgr.on_payment_settled(hash, None).await; + let result = mgr.on_payment_settled(hash, None, None).await; assert!(result.is_ok()); } @@ -512,7 +607,7 @@ mod tests { assert_eq!(mgr.session_count().await, 1); // on_payment_settled hits dead handle → removes entry. - let err = mgr.on_payment_settled(hash, None).await.unwrap_err(); + let err = mgr.on_payment_settled(hash, None, None).await.unwrap_err(); assert!(matches!(err, ManagerError::SessionTerminated { .. })); assert_eq!(mgr.session_count().await, 0); } @@ -520,7 +615,7 @@ mod tests { #[tokio::test(flavor = "current_thread", start_paused = true)] async fn payment_failed_unknown_hash_is_ok() { let mgr = test_manager(true); - let result = mgr.on_payment_failed(test_payment_hash(99)).await; + let result = mgr.on_payment_failed(test_payment_hash(99), None).await; assert!(result.is_ok()); } @@ -537,7 +632,7 @@ mod tests { assert!(matches!(resp, HtlcResponse::Forward { .. })); // Fail payment — session is in AwaitingSettlement. - let result = mgr.on_payment_failed(hash).await; + let result = mgr.on_payment_failed(hash, None).await; assert!(result.is_ok()); } @@ -560,4 +655,165 @@ mod tests { assert!(matches!(r2, HtlcResponse::Forward { .. })); assert_eq!(mgr.session_count().await, 1); } + + struct MockRecoveryProvider { + channel_exists: bool, + forward_activity: ForwardActivity, + } + + impl Default for MockRecoveryProvider { + fn default() -> Self { + Self { + channel_exists: false, + forward_activity: ForwardActivity::NoForwards, + } + } + } + + #[async_trait] + impl RecoveryProvider for MockRecoveryProvider { + async fn get_forward_activity( + &self, + _channel_id: &str, + ) -> anyhow::Result { + Ok(self.forward_activity.clone()) + } + async fn get_channel_recovery_info( + &self, + _channel_id: &str, + ) -> anyhow::Result { + Ok(ChannelRecoveryInfo { + exists: self.channel_exists, + withheld: true, + }) + } + async fn close_and_unreserve( + &self, + _channel_id: &str, + _funding_psbt: &str, + ) -> anyhow::Result<()> { + Ok(()) + } + async fn wait_for_forward_resolution( + &self, + _channel_id: &str, + from_index: u64, + ) -> anyhow::Result<(ForwardActivity, u64)> { + Ok((self.forward_activity.clone(), from_index + 1)) + } + } + + #[tokio::test] + async fn recover_pre_funding_expired_finalizes_as_timeout() { + let mut ds = MockDatastore::new(); + // Clear default entries, add one with expired opening_fee_params. + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.opening_fee_params.valid_until = Utc::now() - ChronoDuration::hours(1); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + )); + + mgr.recover(Arc::new(MockRecoveryProvider::default())) + .await + .unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_pre_funding_valid_leaves_session_for_replay() { + let ds = MockDatastore::new(); // entries have valid_until in future + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + )); + + mgr.recover(Arc::new(MockRecoveryProvider::default())).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + + // Replayed HTLC should still create a fresh session + let _response = mgr.on_part( + test_payment_hash(1), + test_scid(), + part(1, 1_000), + ).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test] + async fn recover_funded_channel_gone_finalizes_abandoned() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-gone".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: false, + forward_activity: ForwardActivity::NoForwards, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_no_forwards_resets_session() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::NoForwards, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_all_failed_finalizes_abandoned() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::AllFailed, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } } diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index 2d3b5c3ce2d2..ff77d607294d 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -47,6 +47,63 @@ pub trait DatastoreProvider: Send + Sync { ) -> Result<()>; async fn update_session_preimage(&self, scid: &ShortChannelId, preimage: &str) -> Result<()>; + + /// List all active session entries (for recovery scan). + async fn list_active_sessions(&self) -> Result>; + + /// Update the forwards_updated_index for a session. + async fn update_session_forwards_index( + &self, + scid: &ShortChannelId, + index: u64, + ) -> Result<()>; + + /// Reset a session's funding fields back to None (for clean restart). + async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()>; +} + +/// Status of forwards on a channel, used during recovery classification. +#[derive(Debug, Clone, PartialEq)] +pub enum ForwardActivity { + /// No forwards ever happened on this channel. + NoForwards, + /// All forwards failed (none settled or offered). + AllFailed, + /// Some forwards are in-flight (OFFERED) but none have settled yet. + Offered, + /// At least one forward has settled. + Settled, +} + +/// Information about a channel needed for recovery classification. +#[derive(Debug, Clone)] +pub struct ChannelRecoveryInfo { + pub exists: bool, + pub withheld: bool, +} + +/// Provides recovery-specific queries. Separated from ActionExecutor +/// to keep the normal operation interface clean. +#[async_trait] +pub trait RecoveryProvider: Send + Sync { + /// Check forward activity on a channel using both in-flight HTLCs + /// and historical forwards. + async fn get_forward_activity(&self, channel_id: &str) -> Result; + + /// Get channel recovery info (exists, withheld status). + async fn get_channel_recovery_info(&self, channel_id: &str) -> Result; + + /// Close a channel and unreserve its inputs. + async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()>; + + /// Monitor forward status changes using the wait subsystem. + /// Returns when a forward on the given channel settles or fails. + /// `from_index` is the last processed updated_index. + async fn wait_for_forward_resolution( + &self, + channel_id: &str, + from_index: u64, + ) -> Result<(ForwardActivity, u64)>; } #[async_trait] diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index 829a6e7ae7ec..d8a479e4f968 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -400,6 +400,22 @@ mod tests { ) -> AnyResult<()> { unimplemented!("not needed for service tests") } + + async fn list_active_sessions(&self) -> AnyResult> { + unimplemented!("not needed for service tests") + } + + async fn update_session_forwards_index( + &self, + _scid: &ShortChannelId, + _index: u64, + ) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } + + async fn reset_session_funding(&self, _scid: &ShortChannelId) -> AnyResult<()> { + unimplemented!("not needed for service tests") + } } fn handler(api: MockApi) -> Lsps2ServiceHandler { diff --git a/plugins/lsps-plugin/src/core/lsps2/session.rs b/plugins/lsps-plugin/src/core/lsps2/session.rs index 3d59012fc939..38952d55e88d 100644 --- a/plugins/lsps-plugin/src/core/lsps2/session.rs +++ b/plugins/lsps-plugin/src/core/lsps2/session.rs @@ -243,6 +243,58 @@ impl Session { } } + /// Reconstruct a session from persisted state for crash recovery. + /// + /// Initializes the FSM in the appropriate state based on whether a + /// preimage was already captured: + /// - `preimage: None` → `AwaitingSettlement` (waiting for payment outcome) + /// - `preimage: Some` → `Broadcasting` (payment settled, need to broadcast) + /// + /// Forwarded HTLC parts are not reconstructed — CLN manages those + /// independently. The FSM only needs channel identity to drive + /// remaining actions. + pub fn recover( + channel_id: String, + funding_psbt: String, + preimage: Option, + opening_fee_params: OpeningFeeParams, + ) -> (Self, Vec) { + let (state, actions) = if preimage.is_some() { + ( + SessionState::Broadcasting { + channel_id: channel_id.clone(), + funding_psbt: funding_psbt.clone(), + }, + vec![SessionAction::BroadcastFundingTx { + channel_id, + funding_psbt, + }], + ) + } else { + ( + SessionState::AwaitingSettlement { + forwarded_parts: vec![], + forwarded_amount_msat: 0, + deducted_fee_msat: 0, + channel_id, + funding_psbt, + }, + vec![], + ) + }; + + let session = Self { + state, + max_parts: 0, + opening_fee_params, + payment_size_msat: None, + channel_capacity_msat: Msat::from_msat(0), + peer_id: String::new(), + }; + + (session, actions) + } + pub fn is_terminal(&self) -> bool { matches!( self.state, @@ -1975,4 +2027,78 @@ mod tests { assert!(matches!(&res.events[0], SessionEvent::UnusualInput { .. })); } } + + #[test] + fn recover_without_preimage_enters_awaiting_settlement() { + let (session, actions) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + assert!(actions.is_empty()); + assert!(!session.is_terminal()); + } + + #[test] + fn recover_with_preimage_enters_broadcasting() { + let (session, actions) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + Some("preimage-1".to_string()), + opening_fee_params(1_000, 0), + ); + assert_eq!(actions.len(), 1); + assert!(matches!( + &actions[0], + SessionAction::BroadcastFundingTx { channel_id, funding_psbt } + if channel_id == "channel-id-1" && funding_psbt == "psbt-1" + )); + assert!(!session.is_terminal()); + } + + #[test] + fn recovered_awaiting_settlement_transitions_on_payment_settled() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::PaymentSettled).unwrap(); + assert!(matches!( + result.actions.as_slice(), + [SessionAction::BroadcastFundingTx { .. }] + )); + } + + #[test] + fn recovered_awaiting_settlement_transitions_on_payment_failed() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + None, + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::PaymentFailed).unwrap(); + assert!(matches!( + result.actions.as_slice(), + [SessionAction::AbandonSession { .. }, SessionAction::Disconnect] + )); + assert!(session.is_terminal()); + } + + #[test] + fn recovered_broadcasting_transitions_on_funding_broadcasted() { + let (mut session, _) = Session::recover( + "channel-id-1".to_string(), + "psbt-1".to_string(), + Some("preimage-1".to_string()), + opening_fee_params(1_000, 0), + ); + let result = session.apply(SessionInput::FundingBroadcasted).unwrap(); + let _ = result; + assert!(session.is_terminal()); + assert_eq!(session.outcome(), Some(SessionOutcome::Succeeded)); + } } diff --git a/plugins/lsps-plugin/src/proto/lsps2.rs b/plugins/lsps-plugin/src/proto/lsps2.rs index 1a042545b162..e8b99315f953 100644 --- a/plugins/lsps-plugin/src/proto/lsps2.rs +++ b/plugins/lsps-plugin/src/proto/lsps2.rs @@ -352,6 +352,9 @@ pub struct DatastoreEntry { #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub preimage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub forwards_updated_index: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 726613565f06..4f31c9bcf322 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -10,7 +10,7 @@ use cln_lsps::{ lsps2::{ actor::HtlcResponse, manager::{PaymentHash, SessionConfig, SessionManager}, - provider::DatastoreProvider, + provider::{DatastoreProvider, RecoveryProvider}, session::PaymentPart, service::Lsps2ServiceHandler, }, @@ -23,7 +23,7 @@ use cln_lsps::{ }, }; use cln_plugin::{options, Plugin}; -use log::{debug, error, trace}; +use log::{debug, error, trace, warn}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -147,6 +147,13 @@ async fn main() -> Result<(), anyhow::Error> { let collect_timeout_secs = plugin.option(&OPTION_COLLECT_TIMEOUT)? as u64; let state = State::new(rpc_path, &secret, collect_timeout_secs); + + // Recover in-flight sessions before processing replayed HTLCs + let recovery: Arc = state.api.clone(); + if let Err(e) = state.session_manager.recover(recovery).await { + warn!("session recovery failed: {e}"); + } + let plugin = plugin.start(state).await?; plugin.join().await } else { @@ -305,6 +312,8 @@ async fn on_forward_event( _ => return Ok(()), }; + let updated_index = event.get("updated_index").and_then(|v| v.as_u64()); + match status { Some("settled") => { let preimage = event @@ -315,7 +324,7 @@ async fn on_forward_event( if let Err(e) = p .state() .session_manager - .on_payment_settled(payment_hash, preimage) + .on_payment_settled(payment_hash, preimage, updated_index) .await { debug!("on_payment_settled error: {e:#}"); @@ -325,7 +334,7 @@ async fn on_forward_event( if let Err(e) = p .state() .session_manager - .on_payment_failed(payment_hash) + .on_payment_failed(payment_hash, updated_index) .await { debug!("on_payment_failed error: {e:#}"); diff --git a/tests/test_cln_lsps.py b/tests/test_cln_lsps.py index 9b066f0dfe1a..c220edaeb32b 100644 --- a/tests/test_cln_lsps.py +++ b/tests/test_cln_lsps.py @@ -20,7 +20,9 @@ } -def setup_lsps2_network(node_factory, bitcoind, lsp_opts=None, client_opts=None): +def setup_lsps2_network( + node_factory, bitcoind, lsp_opts=None, client_opts=None, may_reconnect=False +): """Create l1 (client), l2 (LSP), l3 (payer) with l3--l2 funded. Returns (l1, l2, l3, chanid) where chanid is the l3-l2 channel. @@ -28,12 +30,15 @@ def setup_lsps2_network(node_factory, bitcoind, lsp_opts=None, client_opts=None) opts = lsp_opts or LSP_OPTS client = client_opts or {} l1_opts = {"experimental-lsps-client": None, **client} + if may_reconnect: + l1_opts["may_reconnect"] = True + opts = {**opts, "may_reconnect": True} l1, l2, l3 = node_factory.get_nodes( 3, opts=[ l1_opts, opts, - {}, + {"may_reconnect": True} if may_reconnect else {}, ], ) @@ -654,8 +659,6 @@ def test_lsps2_session_newblock_unsafe_htlc_timeout(node_factory, bitcoind): dec, inv = buy_and_invoice(l1, l2, amt) routehint = only_one(only_one(dec["routes"])) - current_height = l3.rpc.getinfo()["blockheight"] - # Use small delay so cltv_expiry is close to current height. # The htlc_accepted hook intercepts before CLN's CLTV validation, # so the small delta is accepted by the LSPS2 plugin. @@ -776,3 +779,328 @@ def test_lsps2_session_cltv_force_close_abandoned(node_factory, bitcoind): l3.rpc.waitsendpay( dec["payment_hash"], partid=partid, groupid=1, timeout=60 ) + + +def test_lsps2_restart_collecting_htlcs_replayed(node_factory, bitcoind): + """Restart during collecting phase — replayed HTLCs create fresh session. + + Recovery path: pre-funding session in datastore → restart → CLN replays + unhandled HTLCs → new session collects and completes successfully. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind, may_reconnect=True) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 5 + routehint = only_one(only_one(dec["routes"])) + route_part = [ + { + "amount_msat": amt // parts, + "id": l2.info["id"], + "delay": routehint["cltv_expiry_delta"] + 6, + "channel": chanid, + }, + { + "amount_msat": amt // parts, + "id": l1.info["id"], + "delay": 6, + "channel": routehint["short_channel_id"], + }, + ] + + for partid in range( + 1, parts + ): # One part is missing to make sure we are actually in CollectingParts state + l3.rpc.sendpay( + route_part, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=partid, + ) + + # Restart l2 after 4 of 5 parts arrived — plugin has not funded a channel yet + l2.daemon.wait_for_log(r"PaymentPartAdded.*n_parts: 4") + l2.restart() + l2.connect(l3) + l2.connect(l1) + wait_for( + lambda: ( + only_one(l2.rpc.listpeerchannels(l3.info["id"])["channels"]).get("state") + == "CHANNELD_NORMAL" + ) + ) + + # CLN replays all unhandled HTLCs after restart. The recovery + replay + # should result in a successful payment regardless of how far the + # original session got. Still need to send the last part + l3.rpc.sendpay( + route_part, + dec["payment_hash"], + payment_secret=inv["payment_secret"], + bolt11=inv["bolt11"], + amount_msat=f"{amt}msat", + groupid=1, + partid=parts, + ) + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + assert res["payment_preimage"] + + # l1 should have exactly one JIT channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + # Mine a block so the funding confirms. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + + # Finalized entry should show success. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Succeeded" + + # Active entries should be empty. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + +def test_lsps2_restart_pre_funding_expired_finalized_timeout(node_factory, bitcoind): + """Restart with expired pre-funding session — finalized as Timeout. + + Recovery path: session valid_until has passed, no channel funded → + recovery classifies as Timeout. + """ + l1, l2, l3, chanid = setup_lsps2_network(node_factory, bitcoind) + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Tamper with the active session's valid_until so recovery sees it as + # expired. This avoids needing a short-validity policy plugin which + # conflicts with the client's 1-minute safety margin. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + ds_entry = only_one(active["datastore"]) + session = json.loads(ds_entry["string"]) + session["opening_fee_params"]["valid_until"] = "2000-01-01T00:00:00.000Z" + l2.rpc.datastore( + key=ds_entry["key"], + string=json.dumps(session), + mode="must-replace", + ) + + # Restart l2 — recovery finds expired session with no channel. + l2.restart() + + # Recovery should finalize the session as Timeout. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Timeout" + + # Active entries should be cleaned up. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + # No channel should exist between l1 and l2. + chs = l1.rpc.listpeerchannels(l2.info["id"])["channels"] + assert len(chs) == 0 + + +def test_lsps2_restart_awaiting_settlement_payment_completes(node_factory, bitcoind): + """Restart while HTLCs are held — recovered session settles successfully. + + Recovery path: funded session with OFFERED forwards → recover as + AwaitingSettlement → forward monitoring → payment settles → Succeeded. + """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 15}, + may_reconnect=True, + ) + # JIT channels can trigger bookkeeper "Unable to calculate fees" on restart. + l2.broken_log = r"Unable to calculate fees collected" + + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (channel funded, HTLCs forwarded). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 15 seconds") + + # Confirm early persistence: active session has channel_id. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + ) + > 0 + and json.loads( + only_one( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + )["string"] + ).get("channel_id") + is not None + ) + ) + + # Restart l2 while HTLCs are held on l1. + l2.restart() + l2.connect(l3) + l2.connect(l1) + + # Hold expires → l1 settles → recovered actor detects SETTLED → Succeeded. + res = l3.rpc.waitsendpay(dec["payment_hash"], partid=parts, groupid=1, timeout=60) + assert res["payment_preimage"] + + # l1 should have exactly one JIT channel. + chs = l1.rpc.listpeerchannels()["channels"] + assert len(chs) == 1 + + # Mine a block so the funding confirms. + bitcoind.generate_block(1) + wait_for( + lambda: ( + only_one(l1.rpc.listpeerchannels()["channels"]).get("short_channel_id") + is not None + ) + ) + + # Finalized entry should show success with funding_txid. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Succeeded" + assert isinstance(entry["funding_txid"], str) and entry["funding_txid"] + + # Active entries should be empty. + active = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"]) + assert active["datastore"] == [] + + +def test_lsps2_restart_awaiting_settlement_payment_fails_abandoned( + node_factory, bitcoind +): + """Restart while HTLCs are held, payment fails — session Abandoned. + + Recovery path: funded session with OFFERED forwards → recover as + AwaitingSettlement → forward monitoring → forwards fail → Abandoned. + """ + hold_plugin = os.path.join(os.path.dirname(__file__), "plugins/hold_htlcs.py") + l1, l2, l3, chanid = setup_lsps2_network( + node_factory, + bitcoind, + client_opts={"plugin": hold_plugin, "hold-time": 15}, + may_reconnect=True, + ) + # JIT channels can trigger bookkeeper "Unable to calculate fees" on restart. + l2.broken_log = r"Unable to calculate fees collected" + + amt = 10_000_000 + dec, inv = buy_and_invoice(l1, l2, amt) + + # Delete the invoice on l1 so it will reject HTLCs after hold expires. + invoices = l1.rpc.listinvoices()["invoices"] + for i in invoices: + if i["status"] == "unpaid": + l1.rpc.delinvoice(i["label"], "unpaid") + + parts = 2 + send_mpp(l3, l2.info["id"], l1.info["id"], chanid, dec, inv, amt, parts) + + # Wait for l1 to hold HTLCs (channel funded, HTLCs forwarded). + l1.daemon.wait_for_log("Holding onto an incoming htlc for 15 seconds") + + # Confirm early persistence: active session has channel_id. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + ) + > 0 + and json.loads( + only_one( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "active"])[ + "datastore" + ] + )["string"] + ).get("channel_id") + is not None + ) + ) + + # Restart l2 while HTLCs are held. + l2.restart() + l2.connect(l3) + l2.connect(l1) + + # Hold expires → l1 rejects (no invoice) → forwards fail → Abandoned. + for partid in range(1, parts + 1): + with pytest.raises(Exception): + l3.rpc.waitsendpay( + dec["payment_hash"], partid=partid, groupid=1, timeout=60 + ) + + # Finalized entry should show Abandoned. + wait_for( + lambda: ( + len( + l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"])[ + "datastore" + ] + ) + > 0 + ) + ) + ds = l2.rpc.listdatastore(["lsps", "lsps2", "sessions", "finalized"]) + entry = json.loads(only_one(ds["datastore"])["string"]) + assert entry["outcome"] == "Abandoned" + + # Channel should be gone on l2. + wait_for(lambda: len(l2.rpc.listpeerchannels(l1.info["id"])["channels"]) == 0) + + # UTXOs should be unreserved. + assert not any(o["reserved"] for o in l2.rpc.listfunds()["outputs"]) From e68837d50ffcc1a36820d3ce39cf3c94735741a5 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:11:32 +0100 Subject: [PATCH 06/11] plugins(lsps2): simplify DatastoreProvider to focused trait Reduce DatastoreProvider from many methods to 5, with the actor owning the DatastoreEntry and driving all writes through the actor loop. This makes the datastore boundary simpler and testable. --- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 212 +++++------------- plugins/lsps-plugin/src/core/lsps2/actor.rs | 168 ++++++-------- plugins/lsps-plugin/src/core/lsps2/manager.rs | 66 ++---- .../lsps-plugin/src/core/lsps2/provider.rs | 31 +-- plugins/lsps-plugin/src/core/lsps2/service.rs | 77 +++---- 5 files changed, 178 insertions(+), 376 deletions(-) diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 5abf4302c6d8..9e1aea5cdf99 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -156,6 +156,26 @@ impl ClnApiRpc { } } + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + let mut rpc = self.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ]; + + let _ = rpc + .call_typed(&DeldatastoreRequest { + generation: None, + key, + }) + .await; + + Ok(()) + } + /// Get the short_channel_id for a channel, needed for listforwards queries. /// Falls back to alias.local for unconfirmed JIT channels. async fn get_channel_scid(&self, channel_id: &str) -> Result> { @@ -193,7 +213,7 @@ impl ActionExecutor for ClnApiRpc { peer_id: String, channel_size: Msat, _opening_fee_params: OpeningFeeParams, - scid: ShortChannelId, + _scid: ShortChannelId, ) -> anyhow::Result<(String, String)> { let pk = PublicKey::from_str(&peer_id) .with_context(|| format!("parsing peer_id '{peer_id}'"))?; @@ -276,13 +296,6 @@ impl ActionExecutor for ClnApiRpc { }; let channel_id = complete_res.channel_id; - // Early persist: close crash window between fundchannel_complete - // and actor's datastore write. If we crash after fundchannel_complete, - // the withheld channel survives restart — this ensures we know about it. - self.update_session_funding(&scid, &channel_id.to_string(), &psbt) - .await - .context("early persist of funding after fundchannel_complete")?; - if let Err(e) = self .poll_channel_ready( &channel_id, @@ -427,7 +440,8 @@ impl DatastoreProvider for ClnApiRpc { opening_fee_params: &OpeningFeeParams, expected_payment_size: &Option, channel_capacity_msat: &Msat, - ) -> Result { + ) -> Result { + let created_at = chrono::Utc::now(); let mut rpc = self.create_rpc().await?; #[derive(Serialize)] struct BorrowedDatastoreEntry<'a> { @@ -454,7 +468,7 @@ impl DatastoreProvider for ClnApiRpc { opening_fee_params, expected_payment_size, channel_capacity_msat, - created_at: chrono::Utc::now(), + created_at, channel_id: None, funding_psbt: None, funding_txid: None, @@ -483,7 +497,18 @@ impl DatastoreProvider for ClnApiRpc { .map_err(anyhow::Error::new) .with_context(|| "calling datastore")?; - Ok(true) + Ok(DatastoreEntry { + peer_id: *peer_id, + opening_fee_params: opening_fee_params.clone(), + expected_payment_size: *expected_payment_size, + channel_capacity_msat: *channel_capacity_msat, + created_at, + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: None, + }) } async fn get_buy_request(&self, scid: &ShortChannelId) -> Result { @@ -506,23 +531,24 @@ impl DatastoreProvider for ClnApiRpc { Ok(rec) } - async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + async fn save_session(&self, scid: &ShortChannelId, entry: &DatastoreEntry) -> Result<()> { + let json_str = serde_json::to_string(entry)?; let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ]; - - let _ = rpc - .call_typed(&DeldatastoreRequest { - generation: None, - key, - }) - .await; - + rpc.call_typed(&DatastoreRequest { + generation: None, + hex: None, + mode: Some(DatastoreMode::CREATE_OR_REPLACE), + string: Some(json_str), + key: vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ], + }) + .await + .with_context(|| "calling datastore for save_session")?; Ok(()) } @@ -564,88 +590,6 @@ impl DatastoreProvider for ClnApiRpc { Ok(()) } - async fn update_session_funding( - &self, - scid: &ShortChannelId, - channel_id: &str, - funding_psbt: &str, - ) -> Result<()> { - let mut entry = self.get_buy_request(scid).await?; - entry.channel_id = Some(channel_id.to_string()); - entry.funding_psbt = Some(funding_psbt.to_string()); - let json_str = serde_json::to_string(&entry)?; - - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::CREATE_OR_REPLACE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ], - }) - .await - .with_context(|| "calling datastore for update_session_funding")?; - Ok(()) - } - - async fn update_session_funding_txid( - &self, - scid: &ShortChannelId, - funding_txid: &str, - ) -> Result<()> { - let mut entry = self.get_buy_request(scid).await?; - entry.funding_txid = Some(funding_txid.to_string()); - let json_str = serde_json::to_string(&entry)?; - - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::CREATE_OR_REPLACE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ], - }) - .await - .with_context(|| "calling datastore for update_session_funding_txid")?; - Ok(()) - } - - async fn update_session_preimage(&self, scid: &ShortChannelId, preimage: &str) -> Result<()> { - let mut entry = self.get_buy_request(scid).await?; - entry.preimage = Some(preimage.to_string()); - let json_str = serde_json::to_string(&entry)?; - - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::CREATE_OR_REPLACE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ], - }) - .await - .with_context(|| "calling datastore for update_session_preimage")?; - Ok(()) - } - async fn list_active_sessions(&self) -> Result> { let mut rpc = self.create_rpc().await?; let prefix = vec![ @@ -672,56 +616,6 @@ impl DatastoreProvider for ClnApiRpc { } Ok(sessions) } - - async fn update_session_forwards_index(&self, scid: &ShortChannelId, index: u64) -> Result<()> { - let mut entry = self.get_buy_request(scid).await?; - entry.forwards_updated_index = Some(index); - let json_str = serde_json::to_string(&entry)?; - - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::CREATE_OR_REPLACE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ], - }) - .await - .with_context(|| "calling datastore for update_session_forwards_index")?; - Ok(()) - } - - async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()> { - let mut entry = self.get_buy_request(scid).await?; - entry.channel_id = None; - entry.funding_psbt = None; - entry.funding_txid = None; - let json_str = serde_json::to_string(&entry)?; - - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&DatastoreRequest { - generation: None, - hex: None, - mode: Some(DatastoreMode::CREATE_OR_REPLACE), - string: Some(json_str), - key: vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ], - }) - .await - .with_context(|| "calling datastore for reset_session_funding")?; - Ok(()) - } } #[async_trait] diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs index 344e3812d51d..0621c9d2fd56 100644 --- a/plugins/lsps-plugin/src/core/lsps2/actor.rs +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -48,7 +48,7 @@ enum ActorInput { PaymentFailed { updated_index: Option, }, - FundingBroadcasted, + FundingBroadcasted { txid: String }, NewBlock { height: u32, }, @@ -123,6 +123,7 @@ impl ActorInboxHandle { /// effects and actions. pub struct SessionActor { session: Session, + entry: DatastoreEntry, inbox: mpsc::Receiver, pending_htlcs: HashMap>, collect_timeout_handle: Option>, @@ -140,6 +141,7 @@ impl, channel_id: String, executor: A, @@ -179,6 +183,7 @@ impl SessionInput::ChannelReady { - channel_id, - funding_psbt, - }, + } => { + self.entry.channel_id = Some(channel_id.clone()); + self.entry.funding_psbt = Some(funding_psbt.clone()); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on ChannelReady: {e}"); + } + SessionInput::ChannelReady { + channel_id, + funding_psbt, + } + } ActorInput::FundingFailed => SessionInput::FundingFailed, ActorInput::PaymentSettled { preimage, updated_index, } => { if let Some(index) = updated_index { - if let Err(e) = self - .datastore - .update_session_forwards_index(&self.scid, index) - .await - { - warn!("update_session_forwards_index failed: {e}"); - } + self.entry.forwards_updated_index = Some(index); } if let Some(ref pre) = preimage { - if let Err(e) = - self.datastore.update_session_preimage(&self.scid, pre).await - { - warn!("update_session_preimage failed for scid={}: {e}", self.scid); + self.entry.preimage = Some(pre.clone()); + } + if updated_index.is_some() || preimage.is_some() { + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentSettled: {e}"); } } SessionInput::PaymentSettled } ActorInput::PaymentFailed { updated_index } => { if let Some(index) = updated_index { - if let Err(e) = self - .datastore - .update_session_forwards_index(&self.scid, index) - .await - { - warn!("update_session_forwards_index failed: {e}"); + self.entry.forwards_updated_index = Some(index); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentFailed: {e}"); } } SessionInput::PaymentFailed } - ActorInput::FundingBroadcasted => SessionInput::FundingBroadcasted, + ActorInput::FundingBroadcasted { txid } => { + self.entry.funding_txid = Some(txid); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on FundingBroadcasted: {e}"); + } + SessionInput::FundingBroadcasted + } ActorInput::NewBlock { height } => SessionInput::NewBlock { height }, ActorInput::ChannelClosed { channel_id } => { SessionInput::ChannelClosed { channel_id } @@ -305,7 +315,6 @@ impl { for event in &result.events { - // Note: Add event handler later on. debug!("session event: {:?}", event); } @@ -356,7 +365,6 @@ impl { - let _ = - datastore.update_session_forwards_index(&scid, new_index).await; let _ = self_tx .send(ActorInput::PaymentSettled { preimage: None, - updated_index: None, + updated_index: Some(new_index), }) .await; return; } Ok((ForwardActivity::AllFailed, new_index)) => { - let _ = - datastore.update_session_forwards_index(&scid, new_index).await; let _ = self_tx - .send(ActorInput::PaymentFailed { updated_index: None }) + .send(ActorInput::PaymentFailed { + updated_index: Some(new_index), + }) .await; return; } @@ -432,26 +438,37 @@ impl { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + } if let Some(ref pre) = preimage { - let datastore = self.datastore.clone(); - let scid = self.scid; - let pre = pre.clone(); - tokio::spawn(async move { - if let Err(e) = - datastore.update_session_preimage(&scid, &pre).await - { - warn!("update_session_preimage failed: {e}"); - } - }); + self.entry.preimage = Some(pre.clone()); + } + if updated_index.is_some() || preimage.is_some() { + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentSettled: {e}"); + } } SessionInput::PaymentSettled } - ActorInput::PaymentFailed { updated_index: _ } => { + ActorInput::PaymentFailed { updated_index } => { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentFailed: {e}"); + } + } SessionInput::PaymentFailed } - ActorInput::FundingBroadcasted => SessionInput::FundingBroadcasted, + ActorInput::FundingBroadcasted { txid } => { + self.entry.funding_txid = Some(txid); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on FundingBroadcasted: {e}"); + } + SessionInput::FundingBroadcasted + } _ => continue, }; @@ -571,21 +588,13 @@ impl { - if let Err(e) = datastore - .update_session_funding_txid(&scid, &txid) - .await - { - warn!("update_session_funding_txid failed for scid={scid}: {e}"); - } - let _ = self_tx.send(ActorInput::FundingBroadcasted).await; + let _ = self_tx.send(ActorInput::FundingBroadcasted { txid }).await; } Err(e) => { warn!( @@ -657,7 +666,7 @@ impl DatastoreProvider for Arc { offer: &OpeningFeeParams, expected_payment_size: &Option, channel_capacity_msat: &Msat, - ) -> Result { + ) -> Result { (**self) .store_buy_request(scid, peer_id, offer, expected_payment_size, channel_capacity_msat) .await @@ -666,64 +675,27 @@ impl DatastoreProvider for Arc { async fn get_buy_request( &self, scid: &ShortChannelId, - ) -> Result { + ) -> Result { (**self).get_buy_request(scid).await } - async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { - (**self).del_buy_request(scid).await - } - - async fn finalize_session( + async fn save_session( &self, scid: &ShortChannelId, - outcome: crate::proto::lsps2::SessionOutcome, + entry: &DatastoreEntry, ) -> Result<()> { - (**self).finalize_session(scid, outcome).await + (**self).save_session(scid, entry).await } - async fn update_session_funding( - &self, - scid: &ShortChannelId, - channel_id: &str, - funding_psbt: &str, - ) -> Result<()> { - (**self) - .update_session_funding(scid, channel_id, funding_psbt) - .await - } - - async fn update_session_funding_txid( - &self, - scid: &ShortChannelId, - funding_txid: &str, - ) -> Result<()> { - (**self) - .update_session_funding_txid(scid, funding_txid) - .await - } - - async fn update_session_preimage( + async fn finalize_session( &self, scid: &ShortChannelId, - preimage: &str, + outcome: crate::proto::lsps2::SessionOutcome, ) -> Result<()> { - (**self).update_session_preimage(scid, preimage).await + (**self).finalize_session(scid, outcome).await } async fn list_active_sessions(&self) -> Result> { (**self).list_active_sessions().await } - - async fn update_session_forwards_index( - &self, - scid: &ShortChannelId, - index: u64, - ) -> Result<()> { - (**self).update_session_forwards_index(scid, index).await - } - - async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()> { - (**self).reset_session_funding(scid).await - } } diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs index a3dcb1e0d166..9fedae09b559 100644 --- a/plugins/lsps-plugin/src/core/lsps2/manager.rs +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -84,7 +84,11 @@ impl recovery .close_and_unreserve(channel_id, funding_psbt) .await?; - self.datastore.reset_session_funding(&scid).await?; + let mut entry = entry; + entry.channel_id = None; + entry.funding_psbt = None; + entry.funding_txid = None; + self.datastore.save_session(&scid, &entry).await?; } ForwardActivity::AllFailed => { self.datastore @@ -102,13 +106,14 @@ impl let handle = SessionActor::spawn_recovered_session_actor( session, + entry, initial_actions, - channel_id.clone(), + channel_id, self.executor.clone(), scid, self.datastore.clone(), recovery.clone(), - entry.forwards_updated_index, + forwards_updated_index, ); self.recovery_handles.lock().await.push(handle); @@ -230,18 +235,20 @@ impl .await .map_err(ManagerError::DatastoreLookup)?; + let peer_id = entry.peer_id.to_string(); let session = Session::new( self.config.max_parts, - entry.opening_fee_params, + entry.opening_fee_params.clone(), entry.expected_payment_size, entry.channel_capacity_msat, - entry.peer_id.to_string(), + peer_id.clone(), ); Ok(SessionActor::spawn_session_actor( session, + entry, self.executor.clone(), - entry.peer_id.to_string(), + peer_id, self.config.collect_timeout_secs, *scid, self.datastore.clone(), @@ -341,13 +348,13 @@ mod tests { impl DatastoreProvider for MockDatastore { async fn store_buy_request( &self, - _scid: &ShortChannelId, + scid: &ShortChannelId, _peer_id: &bitcoin::secp256k1::PublicKey, _offer: &OpeningFeeParams, _expected_payment_size: &Option, _channel_capacity_msat: &Msat, - ) -> anyhow::Result { - Ok(true) + ) -> anyhow::Result { + self.get_buy_request(scid).await } async fn get_buy_request(&self, scid: &ShortChannelId) -> anyhow::Result { @@ -357,39 +364,18 @@ mod tests { .ok_or_else(|| anyhow::anyhow!("not found: {scid}")) } - async fn del_buy_request(&self, _scid: &ShortChannelId) -> anyhow::Result<()> { - Ok(()) - } - - async fn finalize_session( - &self, - _scid: &ShortChannelId, - _outcome: SessionOutcome, - ) -> anyhow::Result<()> { - Ok(()) - } - - async fn update_session_funding( - &self, - _scid: &ShortChannelId, - _channel_id: &str, - _funding_psbt: &str, - ) -> anyhow::Result<()> { - Ok(()) - } - - async fn update_session_funding_txid( + async fn save_session( &self, _scid: &ShortChannelId, - _funding_txid: &str, + _entry: &DatastoreEntry, ) -> anyhow::Result<()> { Ok(()) } - async fn update_session_preimage( + async fn finalize_session( &self, _scid: &ShortChannelId, - _preimage: &str, + _outcome: SessionOutcome, ) -> anyhow::Result<()> { Ok(()) } @@ -399,18 +385,6 @@ mod tests { (k.parse::().unwrap(), v.clone()) }).collect()) } - - async fn update_session_forwards_index( - &self, - _scid: &ShortChannelId, - _index: u64, - ) -> anyhow::Result<()> { - Ok(()) - } - - async fn reset_session_funding(&self, _scid: &ShortChannelId) -> anyhow::Result<()> { - Ok(()) - } } struct MockExecutor { diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index ff77d607294d..fef5c44da8bc 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -26,40 +26,15 @@ pub trait DatastoreProvider: Send + Sync { offer: &OpeningFeeParams, expected_payment_size: &Option, channel_capacity_msat: &Msat, - ) -> Result; + ) -> Result; async fn get_buy_request(&self, scid: &ShortChannelId) -> Result; - async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()>; - async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()>; + async fn save_session(&self, scid: &ShortChannelId, entry: &DatastoreEntry) -> Result<()>; - async fn update_session_funding( - &self, - scid: &ShortChannelId, - channel_id: &str, - funding_psbt: &str, - ) -> Result<()>; - - async fn update_session_funding_txid( - &self, - scid: &ShortChannelId, - funding_txid: &str, - ) -> Result<()>; - - async fn update_session_preimage(&self, scid: &ShortChannelId, preimage: &str) -> Result<()>; + async fn finalize_session(&self, scid: &ShortChannelId, outcome: SessionOutcome) -> Result<()>; - /// List all active session entries (for recovery scan). async fn list_active_sessions(&self) -> Result>; - - /// Update the forwards_updated_index for a session. - async fn update_session_forwards_index( - &self, - scid: &ShortChannelId, - index: u64, - ) -> Result<()>; - - /// Reset a session's funding fields back to None (for clean restart). - async fn reset_session_funding(&self, scid: &ShortChannelId) -> Result<()>; } /// Status of forwards on a channel, used during recovery classification. diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index d8a479e4f968..bfff87fee1e3 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -130,7 +130,7 @@ impl .channel_capacity_msat .ok_or_else(|| RpcError::internal_error("channel capacity denied by policy"))?; - let ok = self + let _entry = self .api .store_buy_request( &jit_scid, @@ -142,10 +142,6 @@ impl .await .map_err(|_| RpcError::internal_error("internal error"))?; - if !ok { - return Err(RpcError::internal_error("internal error"))?; - } - Ok(Lsps2BuyResponse { jit_channel_scid: jit_scid, // We can make this configurable if necessary. @@ -347,7 +343,7 @@ mod tests { _fee_params: &OpeningFeeParams, payment_size: &Option, _channel_capacity_msat: &Msat, - ) -> AnyResult { + ) -> AnyResult { if *self.store_error.lock().unwrap() { return Err(anyhow!("store error")); } @@ -357,46 +353,49 @@ mod tests { payment_size: *payment_size, }); - Ok(self.store_result.lock().unwrap().unwrap_or(true)) - } - - async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { - unimplemented!("not needed for service tests") - } - - async fn del_buy_request(&self, _scid: &ShortChannelId) -> AnyResult<()> { - unimplemented!("not needed for service tests") - } + if !self.store_result.lock().unwrap().unwrap_or(true) { + return Err(anyhow!("duplicate SCID")); + } - async fn finalize_session( - &self, - _scid: &ShortChannelId, - _outcome: SessionOutcome, - ) -> AnyResult<()> { - unimplemented!("not needed for service tests") + Ok(DatastoreEntry { + peer_id: *peer_id, + opening_fee_params: OpeningFeeParams { + min_fee_msat: Msat(0), + proportional: Ppm(0), + valid_until: Utc::now(), + min_lifetime: 0, + max_client_to_self_delay: 0, + min_payment_size_msat: Msat(0), + max_payment_size_msat: Msat(0), + promise: Promise(String::new()), + }, + expected_payment_size: *payment_size, + channel_capacity_msat: Msat(0), + created_at: Utc::now(), + channel_id: None, + funding_psbt: None, + funding_txid: None, + preimage: None, + forwards_updated_index: None, + }) } - async fn update_session_funding( - &self, - _scid: &ShortChannelId, - _channel_id: &str, - _funding_psbt: &str, - ) -> AnyResult<()> { + async fn get_buy_request(&self, _scid: &ShortChannelId) -> AnyResult { unimplemented!("not needed for service tests") } - async fn update_session_funding_txid( + async fn save_session( &self, _scid: &ShortChannelId, - _funding_txid: &str, + _entry: &DatastoreEntry, ) -> AnyResult<()> { unimplemented!("not needed for service tests") } - async fn update_session_preimage( + async fn finalize_session( &self, _scid: &ShortChannelId, - _preimage: &str, + _outcome: SessionOutcome, ) -> AnyResult<()> { unimplemented!("not needed for service tests") } @@ -404,18 +403,6 @@ mod tests { async fn list_active_sessions(&self) -> AnyResult> { unimplemented!("not needed for service tests") } - - async fn update_session_forwards_index( - &self, - _scid: &ShortChannelId, - _index: u64, - ) -> AnyResult<()> { - unimplemented!("not needed for service tests") - } - - async fn reset_session_funding(&self, _scid: &ShortChannelId) -> AnyResult<()> { - unimplemented!("not needed for service tests") - } } fn handler(api: MockApi) -> Lsps2ServiceHandler { @@ -662,7 +649,7 @@ mod tests { } #[tokio::test] - async fn buy_handles_store_returns_false() { + async fn buy_handles_store_duplicate_error() { let api = MockApi::new() .with_blockheight(800_000) .with_store_result(false) From d54842d64c958d8853216657ad60b9d63f0d59e1 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:11:47 +0100 Subject: [PATCH 07/11] plugins(lsps2): add EventSink trait for session event notification Add an EventSink trait that decouples session event reporting from the transport layer. Includes a composite sink and a channel-based implementation. Wires EventSink through SessionActor and SessionManager, and persists payment_hash in DatastoreEntry. --- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 4 + plugins/lsps-plugin/src/core/lsps2/actor.rs | 36 ++++- .../lsps-plugin/src/core/lsps2/event_sink.rs | 130 ++++++++++++++++++ plugins/lsps-plugin/src/core/lsps2/manager.rs | 26 +++- plugins/lsps-plugin/src/core/lsps2/mod.rs | 1 + plugins/lsps-plugin/src/core/lsps2/service.rs | 1 + plugins/lsps-plugin/src/proto/lsps2.rs | 3 + plugins/lsps-plugin/src/service.rs | 2 + 8 files changed, 196 insertions(+), 7 deletions(-) create mode 100644 plugins/lsps-plugin/src/core/lsps2/event_sink.rs diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 9e1aea5cdf99..637aea356d01 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -461,6 +461,8 @@ impl DatastoreProvider for ClnApiRpc { preimage: Option, #[serde(skip_serializing_if = "Option::is_none")] forwards_updated_index: &'a Option, + #[serde(skip_serializing_if = "Option::is_none")] + payment_hash: Option, } let ds = BorrowedDatastoreEntry { @@ -474,6 +476,7 @@ impl DatastoreProvider for ClnApiRpc { funding_txid: None, preimage: None, forwards_updated_index: &None, + payment_hash: None, }; let json_str = serde_json::to_string(&ds)?; @@ -508,6 +511,7 @@ impl DatastoreProvider for ClnApiRpc { funding_txid: None, preimage: None, forwards_updated_index: None, + payment_hash: None, }) } diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs index 0621c9d2fd56..13a39d458c5b 100644 --- a/plugins/lsps-plugin/src/core/lsps2/actor.rs +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -1,7 +1,8 @@ use crate::{ core::lsps2::{ + event_sink::{EventSink, SessionEventEnvelope}, provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}, - session::{PaymentPart, Session, SessionAction, SessionInput}, + session::{PaymentPart, Session, SessionAction, SessionEvent, SessionInput}, }, proto::{ lsps0::{Msat, ShortChannelId}, @@ -10,6 +11,8 @@ use crate::{ }; use anyhow::Result; use async_trait::async_trait; +use bitcoin::hashes::sha256::Hash as PaymentHash; +use bitcoin::hashes::Hash; use log::{debug, warn}; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{ @@ -134,6 +137,7 @@ pub struct SessionActor { collect_timeout_secs: u64, scid: ShortChannelId, datastore: D, + event_sink: Arc, } impl @@ -147,6 +151,7 @@ impl, ) -> ActorInboxHandle { let (tx, inbox) = mpsc::channel(128); // Should we use max_htlcs? let actor = SessionActor { @@ -162,6 +167,7 @@ impl, forwards_updated_index: Option, + event_sink: Arc, ) -> ActorInboxHandle { let (tx, inbox) = mpsc::channel(128); let handle = ActorInboxHandle { tx: tx.clone() }; @@ -194,6 +201,7 @@ impl) { + let payment_hash = match self.entry.payment_hash.as_deref() { + Some(s) => match s.parse::() { + Ok(h) => h, + Err(e) => { + warn!("malformed payment_hash in datastore for scid={}: {e}", self.scid); + PaymentHash::all_zeros() + } + }, + None => PaymentHash::all_zeros(), + }; + for event in events { + debug!("session event: {:?}", event); + self.event_sink.send(&SessionEventEnvelope { + scid: self.scid, + payment_hash, + event, + }); + } + } + fn start_collect_timeout(&mut self) { let tx = self.self_send.clone(); let timeout = Duration::from_secs(self.collect_timeout_secs); @@ -314,9 +343,7 @@ impl { - for event in &result.events { - debug!("session event: {:?}", event); - } + self.dispatch_events(result.events); for action in result.actions { self.execute_action(action); @@ -474,6 +501,7 @@ impl { + self.dispatch_events(result.events); for action in result.actions { self.execute_action(action); } diff --git a/plugins/lsps-plugin/src/core/lsps2/event_sink.rs b/plugins/lsps-plugin/src/core/lsps2/event_sink.rs new file mode 100644 index 000000000000..0911d28a813e --- /dev/null +++ b/plugins/lsps-plugin/src/core/lsps2/event_sink.rs @@ -0,0 +1,130 @@ +use crate::core::lsps2::session::SessionEvent; +use crate::proto::lsps0::ShortChannelId; +use bitcoin::hashes::sha256::Hash as PaymentHash; +use std::sync::Arc; +use tokio::sync::mpsc; + +#[derive(Debug, Clone)] +pub struct SessionEventEnvelope { + pub scid: ShortChannelId, + pub payment_hash: PaymentHash, + pub event: SessionEvent, +} + +pub trait EventSink: Send + Sync { + fn send(&self, envelope: &SessionEventEnvelope); +} + +pub struct NoopEventSink; +impl EventSink for NoopEventSink { + fn send(&self, _: &SessionEventEnvelope) {} +} + +pub struct CompositeEventSink { + sinks: Vec>, +} + +impl CompositeEventSink { + pub fn new(sinks: Vec>) -> Self { + Self { sinks } + } +} + +impl EventSink for CompositeEventSink { + fn send(&self, envelope: &SessionEventEnvelope) { + for sink in &self.sinks { + sink.send(envelope); + } + } +} + +pub struct ChannelEventSink { + tx: mpsc::UnboundedSender, +} + +impl ChannelEventSink { + pub fn new() -> (Self, mpsc::UnboundedReceiver) { + let (tx, rx) = mpsc::unbounded_channel(); + (Self { tx }, rx) + } +} + +impl EventSink for ChannelEventSink { + fn send(&self, envelope: &SessionEventEnvelope) { + let _ = self.tx.send(envelope.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::lsps2::session::SessionEvent; + use crate::proto::lsps0::ShortChannelId; + use bitcoin::hashes::sha256::Hash as PaymentHash; + use bitcoin::hashes::Hash; + use std::sync::atomic::{AtomicUsize, Ordering}; + + fn test_envelope() -> SessionEventEnvelope { + SessionEventEnvelope { + scid: ShortChannelId::from(100u64 << 40 | 1u64 << 16), + payment_hash: PaymentHash::from_byte_array([1; 32]), + event: SessionEvent::FundingChannel, + } + } + + struct CountingSink(AtomicUsize); + impl CountingSink { + fn new() -> Self { Self(AtomicUsize::new(0)) } + fn count(&self) -> usize { self.0.load(Ordering::SeqCst) } + } + impl EventSink for CountingSink { + fn send(&self, _: &SessionEventEnvelope) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + + #[test] + fn noop_sink_does_not_panic() { + let sink = NoopEventSink; + sink.send(&test_envelope()); + } + + #[test] + fn composite_fans_out_to_all_sinks() { + let s1 = Arc::new(CountingSink::new()); + let s2 = Arc::new(CountingSink::new()); + let composite = CompositeEventSink::new(vec![ + s1.clone() as Arc, + s2.clone(), + ]); + composite.send(&test_envelope()); + composite.send(&test_envelope()); + assert_eq!(s1.count(), 2); + assert_eq!(s2.count(), 2); + } + + #[test] + fn composite_with_no_sinks_does_not_panic() { + let composite = CompositeEventSink::new(vec![]); + composite.send(&test_envelope()); + } + + #[tokio::test] + async fn channel_sink_delivers_to_receiver() { + let (sink, mut rx) = ChannelEventSink::new(); + let envelope = test_envelope(); + sink.send(&envelope); + sink.send(&envelope); + let received = rx.recv().await.unwrap(); + assert_eq!(received.scid, envelope.scid); + let received2 = rx.recv().await.unwrap(); + assert_eq!(received2.scid, envelope.scid); + } + + #[test] + fn channel_sink_silently_drops_when_receiver_gone() { + let (sink, rx) = ChannelEventSink::new(); + drop(rx); + sink.send(&test_envelope()); // must not panic + } +} diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs index 9fedae09b559..f17bdb1f236a 100644 --- a/plugins/lsps-plugin/src/core/lsps2/manager.rs +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -2,6 +2,7 @@ use super::actor::{ActionExecutor, ActorInboxHandle, HtlcResponse}; use super::provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}; use super::session::{PaymentPart, Session}; use crate::core::lsps2::actor::SessionActor; +use crate::core::lsps2::event_sink::EventSink; use crate::proto::lsps0::ShortChannelId; use crate::proto::lsps2::SessionOutcome; pub use bitcoin::hashes::sha256::Hash as PaymentHash; @@ -39,18 +40,20 @@ pub struct SessionManager { datastore: Arc, executor: Arc, config: SessionConfig, + event_sink: Arc, } impl SessionManager { - pub fn new(datastore: Arc, executor: Arc, config: SessionConfig) -> Self { + pub fn new(datastore: Arc, executor: Arc, config: SessionConfig, event_sink: Arc) -> Self { Self { sessions: Mutex::new(HashMap::new()), recovery_handles: Mutex::new(Vec::new()), datastore, executor, config, + event_sink, } } @@ -114,6 +117,7 @@ impl self.datastore.clone(), recovery.clone(), forwards_updated_index, + self.event_sink.clone(), ); self.recovery_handles.lock().await.push(handle); @@ -144,7 +148,7 @@ impl if let Some(handle) = sessions.get(&payment_hash) { handle.clone() } else { - let handle = self.create_session(&scid).await?; + let handle = self.create_session(&scid, &payment_hash).await?; sessions.insert(payment_hash, handle.clone()); handle } @@ -228,13 +232,20 @@ impl async fn create_session( &self, scid: &ShortChannelId, + payment_hash: &PaymentHash, ) -> Result { - let entry = self + let mut entry = self .datastore .get_buy_request(scid) .await .map_err(ManagerError::DatastoreLookup)?; + entry.payment_hash = Some(payment_hash.to_string()); + self.datastore + .save_session(scid, &entry) + .await + .map_err(ManagerError::DatastoreLookup)?; + let peer_id = entry.peer_id.to_string(); let session = Session::new( self.config.max_parts, @@ -252,6 +263,7 @@ impl self.config.collect_timeout_secs, *scid, self.datastore.clone(), + self.event_sink.clone(), )) } @@ -264,6 +276,7 @@ impl #[cfg(test)] mod tests { use super::*; + use crate::core::lsps2::event_sink::NoopEventSink; use crate::core::lsps2::provider::{ChannelRecoveryInfo, ForwardActivity, RecoveryProvider}; use crate::proto::lsps0::{Msat, Ppm}; use crate::proto::lsps2::{DatastoreEntry, OpeningFeeParams, Promise, SessionOutcome}; @@ -320,6 +333,7 @@ mod tests { funding_txid: None, preimage: None, forwards_updated_index: None, + payment_hash: None, } } @@ -440,6 +454,7 @@ mod tests { max_parts: 3, ..SessionConfig::default() }, + Arc::new(NoopEventSink), )) } @@ -690,6 +705,7 @@ mod tests { Arc::new(ds), Arc::new(MockExecutor { fund_succeeds: true }), SessionConfig::default(), + Arc::new(NoopEventSink), )); mgr.recover(Arc::new(MockRecoveryProvider::default())) @@ -705,6 +721,7 @@ mod tests { Arc::new(ds), Arc::new(MockExecutor { fund_succeeds: true }), SessionConfig::default(), + Arc::new(NoopEventSink), )); mgr.recover(Arc::new(MockRecoveryProvider::default())).await.unwrap(); @@ -732,6 +749,7 @@ mod tests { Arc::new(ds), Arc::new(MockExecutor { fund_succeeds: true }), SessionConfig::default(), + Arc::new(NoopEventSink), )); let recovery = Arc::new(MockRecoveryProvider { @@ -756,6 +774,7 @@ mod tests { Arc::new(ds), Arc::new(MockExecutor { fund_succeeds: true }), SessionConfig::default(), + Arc::new(NoopEventSink), )); let recovery = Arc::new(MockRecoveryProvider { @@ -780,6 +799,7 @@ mod tests { Arc::new(ds), Arc::new(MockExecutor { fund_succeeds: true }), SessionConfig::default(), + Arc::new(NoopEventSink), )); let recovery = Arc::new(MockRecoveryProvider { diff --git a/plugins/lsps-plugin/src/core/lsps2/mod.rs b/plugins/lsps-plugin/src/core/lsps2/mod.rs index 22cc81b36983..eadfdadc890d 100644 --- a/plugins/lsps-plugin/src/core/lsps2/mod.rs +++ b/plugins/lsps-plugin/src/core/lsps2/mod.rs @@ -1,4 +1,5 @@ pub mod actor; +pub mod event_sink; pub mod manager; pub mod provider; pub mod service; diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index bfff87fee1e3..5fe47973df0f 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -377,6 +377,7 @@ mod tests { funding_txid: None, preimage: None, forwards_updated_index: None, + payment_hash: None, }) } diff --git a/plugins/lsps-plugin/src/proto/lsps2.rs b/plugins/lsps-plugin/src/proto/lsps2.rs index e8b99315f953..e5d6ad729be0 100644 --- a/plugins/lsps-plugin/src/proto/lsps2.rs +++ b/plugins/lsps-plugin/src/proto/lsps2.rs @@ -355,6 +355,9 @@ pub struct DatastoreEntry { #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub forwards_updated_index: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + pub payment_hash: Option, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 4f31c9bcf322..0ec8e5cf9650 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -9,6 +9,7 @@ use cln_lsps::{ core::{ lsps2::{ actor::HtlcResponse, + event_sink::NoopEventSink, manager::{PaymentHash, SessionConfig, SessionManager}, provider::{DatastoreProvider, RecoveryProvider}, session::PaymentPart, @@ -67,6 +68,7 @@ impl State { collect_timeout_secs, ..SessionConfig::default() }, + Arc::new(NoopEventSink), )); Self { lsps_service, From 422cf9f327b07bab3731c192a6e6092062d25ed8 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:12:03 +0100 Subject: [PATCH 08/11] plugins(lsps2): decouple lsps2 crate from CLN-specific types Replace CLN-specific types (cln_rpc PublicKey, ShortChannelId alias) with standalone alternatives, feature-gate CLN dependencies behind a "cln" feature flag, split ClnApiRpc into focused adapter structs, and refactor Lsps2ServiceHandler generics for cleaner trait boundaries. This makes the lsps2 core reusable outside of CLN. --- plugins/lsps-plugin/Cargo.toml | 12 +- plugins/lsps-plugin/src/client.rs | 2 +- plugins/lsps-plugin/src/cln_adapters/mod.rs | 5 + plugins/lsps-plugin/src/cln_adapters/rpc.rs | 309 ++++++++++++------ plugins/lsps-plugin/src/core/lsps2/manager.rs | 9 +- plugins/lsps-plugin/src/core/lsps2/service.rs | 41 ++- plugins/lsps-plugin/src/lib.rs | 1 + plugins/lsps-plugin/src/proto/lsps0.rs | 76 ++++- plugins/lsps-plugin/src/proto/lsps2.rs | 2 +- plugins/lsps-plugin/src/service.rs | 35 +- 10 files changed, 359 insertions(+), 133 deletions(-) diff --git a/plugins/lsps-plugin/Cargo.toml b/plugins/lsps-plugin/Cargo.toml index d1b99ada32ee..1dd6c13fec4e 100644 --- a/plugins/lsps-plugin/Cargo.toml +++ b/plugins/lsps-plugin/Cargo.toml @@ -6,18 +6,24 @@ edition = "2021" [[bin]] name = "cln-lsps-client" path = "src/client.rs" +required-features = ["cln"] [[bin]] name = "cln-lsps-service" path = "src/service.rs" +required-features = ["cln"] + +[features] +default = ["cln"] +cln = ["cln-plugin", "cln-rpc"] [dependencies] anyhow = "1.0" async-trait = "0.1" -bitcoin = "0.31" +bitcoin = { version = "0.31", features = ["serde"] } chrono = { version= "0.4.42", features = ["serde"] } -cln-plugin = { workspace = true } -cln-rpc = { workspace = true } +cln-plugin = { workspace = true, optional = true } +cln-rpc = { workspace = true, optional = true } hex = "0.4" log = "0.4" paste = "1.0.15" diff --git a/plugins/lsps-plugin/src/client.rs b/plugins/lsps-plugin/src/client.rs index ceb06391398f..bb15cd7ece9b 100644 --- a/plugins/lsps-plugin/src/client.rs +++ b/plugins/lsps-plugin/src/client.rs @@ -482,7 +482,7 @@ async fn on_lsps_lsps2_invoice( // 5. Approve jit_channel_scid for a jit channel opening. let appr_req = ClnRpcLsps2Approve { lsp_id: req.lsp_id, - jit_channel_scid: buy_res.jit_channel_scid, + jit_channel_scid: buy_res.jit_channel_scid.into(), payment_hash: public_inv.payment_hash.to_string(), client_trusts_lsp: Some(buy_res.client_trusts_lsp), }; diff --git a/plugins/lsps-plugin/src/cln_adapters/mod.rs b/plugins/lsps-plugin/src/cln_adapters/mod.rs index 063690099ca2..821b16f474b0 100644 --- a/plugins/lsps-plugin/src/cln_adapters/mod.rs +++ b/plugins/lsps-plugin/src/cln_adapters/mod.rs @@ -3,3 +3,8 @@ pub mod rpc; pub mod sender; pub mod state; pub mod types; + +pub use rpc::{ + ClnActionExecutor, ClnBlockheight, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, + ClnRpcClient, +}; diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 637aea356d01..5b4b814dc421 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -46,22 +46,26 @@ pub const DS_SESSIONS_KEY: &str = "sessions"; pub const DS_ACTIVE_KEY: &str = "active"; pub const DS_FINALIZED_KEY: &str = "finalized"; +// --------------------------------------------------------------------------- +// ClnRpcClient — shared connection helper +// --------------------------------------------------------------------------- + #[derive(Clone)] -pub struct ClnApiRpc { +pub struct ClnRpcClient { rpc_path: PathBuf, } -impl ClnApiRpc { +impl ClnRpcClient { pub fn new(rpc_path: PathBuf) -> Self { Self { rpc_path } } - async fn create_rpc(&self) -> Result { + pub async fn create_rpc(&self) -> Result { // Note: Add retry and backoff, be nicer than just failing. ClnRpc::new(&self.rpc_path).await } - async fn poll_channel_ready( + pub async fn poll_channel_ready( &self, channel_id: &Sha256, timeout: Duration, @@ -82,7 +86,7 @@ impl ClnApiRpc { } } - async fn check_channel_normal(&self, channel_id: &Sha256) -> Result { + pub async fn check_channel_normal(&self, channel_id: &Sha256) -> Result { let mut rpc = self.create_rpc().await?; let r = rpc .call_typed(&ListpeerchannelsRequest { @@ -98,37 +102,7 @@ impl ClnApiRpc { .is_some_and(|ch| ch.state == ChannelState::CHANNELD_NORMAL)) } - async fn cleanup_failed_funding(&self, peer_id: &PublicKey, psbt: &str) { - if let Err(e) = self.unreserve_inputs(psbt).await { - warn!("cleanup: unreserveinputs for psbt={psbt} failed: {e}"); - } - if let Err(e) = self.cancel_fundchannel(peer_id).await { - warn!("cleanup: fundchannel_cancel failed: {e}"); - } - } - - async fn unreserve_inputs(&self, psbt: &str) -> Result<()> { - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&UnreserveinputsRequest { - reserve: None, - psbt: psbt.to_string(), - }) - .await - .with_context(|| "calling unreserveinputs")?; - Ok(()) - } - - async fn cancel_fundchannel(&self, peer_id: &PublicKey) -> Result<()> { - let mut rpc = self.create_rpc().await?; - rpc.call_typed(&FundchannelCancelRequest { - id: peer_id.to_owned(), - }) - .await - .with_context(|| "calling fundchannel_cancel")?; - Ok(()) - } - - async fn connect_with_retry(&self, peer_id: &str, timeout: Duration) -> Result<()> { + pub async fn connect_with_retry(&self, peer_id: &str, timeout: Duration) -> Result<()> { let deadline = tokio::time::Instant::now() + timeout; let mut backoff = Duration::from_secs(1); let max_backoff = Duration::from_secs(10); @@ -156,29 +130,12 @@ impl ClnApiRpc { } } - async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { - let mut rpc = self.create_rpc().await?; - let key = vec![ - DS_MAIN_KEY.to_string(), - DS_SUB_KEY.to_string(), - DS_SESSIONS_KEY.to_string(), - DS_ACTIVE_KEY.to_string(), - scid.to_string(), - ]; - - let _ = rpc - .call_typed(&DeldatastoreRequest { - generation: None, - key, - }) - .await; - - Ok(()) - } - /// Get the short_channel_id for a channel, needed for listforwards queries. /// Falls back to alias.local for unconfirmed JIT channels. - async fn get_channel_scid(&self, channel_id: &str) -> Result> { + pub async fn get_channel_scid( + &self, + channel_id: &str, + ) -> Result> { let mut rpc = self.create_rpc().await?; let peers = rpc .call_typed(&ListpeerchannelsRequest { @@ -199,15 +156,60 @@ impl ClnApiRpc { } Ok(None) } + + pub async fn unreserve_inputs(&self, psbt: &str) -> Result<()> { + let mut rpc = self.create_rpc().await?; + rpc.call_typed(&UnreserveinputsRequest { + reserve: None, + psbt: psbt.to_string(), + }) + .await + .with_context(|| "calling unreserveinputs")?; + Ok(()) + } } +// --------------------------------------------------------------------------- +// ClnActionExecutor — implements ActionExecutor +// --------------------------------------------------------------------------- + /// Converts msat to sat, rounding up to avoid underfunding. fn msat_to_sat_ceil(msat: u64) -> u64 { msat.div_ceil(1000) } +#[derive(Clone)] +pub struct ClnActionExecutor { + rpc: ClnRpcClient, +} + +impl ClnActionExecutor { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } + + async fn cleanup_failed_funding(&self, peer_id: &PublicKey, psbt: &str) { + if let Err(e) = self.rpc.unreserve_inputs(psbt).await { + warn!("cleanup: unreserveinputs for psbt={psbt} failed: {e}"); + } + if let Err(e) = self.cancel_fundchannel(peer_id).await { + warn!("cleanup: fundchannel_cancel failed: {e}"); + } + } + + async fn cancel_fundchannel(&self, peer_id: &PublicKey) -> Result<()> { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&FundchannelCancelRequest { + id: peer_id.to_owned(), + }) + .await + .with_context(|| "calling fundchannel_cancel")?; + Ok(()) + } +} + #[async_trait] -impl ActionExecutor for ClnApiRpc { +impl ActionExecutor for ClnActionExecutor { async fn fund_channel( &self, peer_id: String, @@ -219,10 +221,10 @@ impl ActionExecutor for ClnApiRpc { .with_context(|| format!("parsing peer_id '{peer_id}'"))?; let channel_sat = msat_to_sat_ceil(channel_size.msat()); - self.connect_with_retry(&peer_id, Duration::from_secs(90)) + self.rpc.connect_with_retry(&peer_id, Duration::from_secs(90)) .await?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let start_res = rpc .call_typed(&FundchannelStartRequest { id: pk, @@ -240,7 +242,7 @@ impl ActionExecutor for ClnApiRpc { let funding_address = start_res.funding_address; // Reserve input and add to tx - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let fundpsbt_res = match rpc .call_typed(&FundpsbtRequest { satoshi: AmountOrAll::Amount(Amount::from_sat(channel_sat)), @@ -297,6 +299,7 @@ impl ActionExecutor for ClnApiRpc { let channel_id = complete_res.channel_id; if let Err(e) = self + .rpc .poll_channel_ready( &channel_id, Duration::from_secs(120), @@ -320,7 +323,7 @@ impl ActionExecutor for ClnApiRpc { let sha = channel_id .parse::() .with_context(|| format!("parsing channel_id '{channel_id}'"))?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let list_res = rpc .call_typed(&ListpeerchannelsRequest { channel_id: Some(sha), @@ -344,7 +347,7 @@ impl ActionExecutor for ClnApiRpc { } } - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let sign_res = rpc .call_typed(&SignpsbtRequest { psbt: funding_psbt, @@ -376,7 +379,7 @@ impl ActionExecutor for ClnApiRpc { } let close_res = { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; rpc.call_typed(&CloseRequest { destination: None, fee_negotiation_step: None, @@ -394,7 +397,7 @@ impl ActionExecutor for ClnApiRpc { warn!("abandon_session: close failed for channel_id={channel_id}: {e}"); } - let unreserve_res = self.unreserve_inputs(&funding_psbt).await; + let unreserve_res = self.rpc.unreserve_inputs(&funding_psbt).await; if let Err(e) = &unreserve_res { warn!("abandon_session: unreserveinputs failed for funding_psbt={funding_psbt}: {e}"); } @@ -412,7 +415,7 @@ impl ActionExecutor for ClnApiRpc { async fn disconnect(&self, peer_id: String) -> anyhow::Result<()> { let pk = PublicKey::from_str(&peer_id) .with_context(|| format!("parsing peer_id '{peer_id}'"))?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let _ = rpc .call_typed(&DisconnectRequest { id: pk, @@ -427,12 +430,47 @@ impl ActionExecutor for ClnApiRpc { let sha = channel_id .parse::() .with_context(|| format!("parsing channel_id '{channel_id}'"))?; - self.check_channel_normal(&sha).await + self.rpc.check_channel_normal(&sha).await + } +} + +// --------------------------------------------------------------------------- +// ClnDatastore — implements DatastoreProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnDatastore { + rpc: ClnRpcClient, +} + +impl ClnDatastore { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } + + async fn del_buy_request(&self, scid: &ShortChannelId) -> Result<()> { + let mut rpc = self.rpc.create_rpc().await?; + let key = vec![ + DS_MAIN_KEY.to_string(), + DS_SUB_KEY.to_string(), + DS_SESSIONS_KEY.to_string(), + DS_ACTIVE_KEY.to_string(), + scid.to_string(), + ]; + + let _ = rpc + .call_typed(&DeldatastoreRequest { + generation: None, + key, + }) + .await; + + Ok(()) } } #[async_trait] -impl DatastoreProvider for ClnApiRpc { +impl DatastoreProvider for ClnDatastore { async fn store_buy_request( &self, scid: &ShortChannelId, @@ -442,7 +480,7 @@ impl DatastoreProvider for ClnApiRpc { channel_capacity_msat: &Msat, ) -> Result { let created_at = chrono::Utc::now(); - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; #[derive(Serialize)] struct BorrowedDatastoreEntry<'a> { peer_id: &'a PublicKey, @@ -516,7 +554,7 @@ impl DatastoreProvider for ClnApiRpc { } async fn get_buy_request(&self, scid: &ShortChannelId) -> Result { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), @@ -537,7 +575,7 @@ impl DatastoreProvider for ClnApiRpc { async fn save_session(&self, scid: &ShortChannelId, entry: &DatastoreEntry) -> Result<()> { let json_str = serde_json::to_string(entry)?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; rpc.call_typed(&DatastoreRequest { generation: None, hex: None, @@ -572,7 +610,7 @@ impl DatastoreProvider for ClnApiRpc { }; let json_str = serde_json::to_string(&finalized)?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let key = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), @@ -595,7 +633,7 @@ impl DatastoreProvider for ClnApiRpc { } async fn list_active_sessions(&self) -> Result> { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let prefix = vec![ DS_MAIN_KEY.to_string(), DS_SUB_KEY.to_string(), @@ -622,20 +660,63 @@ impl DatastoreProvider for ClnApiRpc { } } +// --------------------------------------------------------------------------- +// ClnBlockheight — implements BlockheightProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnBlockheight { + rpc: ClnRpcClient, +} + +impl ClnBlockheight { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } +} + #[async_trait] -impl Lsps2PolicyProvider for ClnApiRpc { +impl BlockheightProvider for ClnBlockheight { + async fn get_blockheight(&self) -> Result { + let mut rpc = self.rpc.create_rpc().await?; + let info = rpc + .call_typed(&GetinfoRequest {}) + .await + .map_err(anyhow::Error::new) + .with_context(|| "calling getinfo")?; + Ok(info.blockheight) + } +} + +// --------------------------------------------------------------------------- +// ClnPolicyProvider — implements Lsps2PolicyProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnPolicyProvider { + rpc: ClnRpcClient, +} + +impl ClnPolicyProvider { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } + } +} + +#[async_trait] +impl Lsps2PolicyProvider for ClnPolicyProvider { async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, ) -> Result { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; rpc.call_raw("lsps2-policy-getpolicy", request) .await .context("failed to call lsps2-policy-getpolicy") } async fn buy(&self, request: &Lsps2PolicyBuyRequest) -> Result { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; rpc.call_raw("lsps2-policy-buy", request) .await .map_err(anyhow::Error::new) @@ -643,24 +724,26 @@ impl Lsps2PolicyProvider for ClnApiRpc { } } -#[async_trait] -impl BlockheightProvider for ClnApiRpc { - async fn get_blockheight(&self) -> Result { - let mut rpc = self.create_rpc().await?; - let info = rpc - .call_typed(&GetinfoRequest {}) - .await - .map_err(anyhow::Error::new) - .with_context(|| "calling getinfo")?; - Ok(info.blockheight) +// --------------------------------------------------------------------------- +// ClnRecoveryProvider — implements RecoveryProvider +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct ClnRecoveryProvider { + rpc: ClnRpcClient, +} + +impl ClnRecoveryProvider { + pub fn new(rpc: ClnRpcClient) -> Self { + Self { rpc } } } #[async_trait] -impl RecoveryProvider for ClnApiRpc { +impl RecoveryProvider for ClnRecoveryProvider { async fn get_forward_activity(&self, channel_id: &str) -> Result { // Check historical forwards via listforwards using out_channel filter. - let scid = match self.get_channel_scid(channel_id).await? { + let scid = match self.rpc.get_channel_scid(channel_id).await? { Some(s) => s, None => { // Channel has no scid yet — no forwards possible. @@ -668,7 +751,7 @@ impl RecoveryProvider for ClnApiRpc { } }; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let fwd_res = rpc .call_typed(&ListforwardsRequest { in_channel: None, @@ -710,7 +793,7 @@ impl RecoveryProvider for ClnApiRpc { let sha = channel_id .parse::() .with_context(|| format!("parsing channel_id '{channel_id}'"))?; - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let list_res = rpc .call_typed(&ListpeerchannelsRequest { channel_id: Some(sha), @@ -740,8 +823,44 @@ impl RecoveryProvider for ClnApiRpc { } async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()> { - self.abandon_session(channel_id.to_string(), funding_psbt.to_string()) + let sha = channel_id.parse::() + .with_context(|| format!("parsing channel_id '{channel_id}'"))?; + if !self.rpc.check_channel_normal(&sha).await.unwrap_or(false) { + return Ok(()); + } + + let close_res = { + let mut rpc = self.rpc.create_rpc().await?; + rpc.call_typed(&CloseRequest { + destination: None, + fee_negotiation_step: None, + force_lease_closed: None, + unilateraltimeout: Some(1), + wrong_funding: None, + feerange: None, + id: channel_id.to_string(), + }) .await + .with_context(|| format!("calling close for channel_id={channel_id}")) + }; + + if let Err(e) = &close_res { + warn!("close_and_unreserve: close failed for channel_id={channel_id}: {e}"); + } + + let unreserve_res = self.rpc.unreserve_inputs(funding_psbt).await; + if let Err(e) = &unreserve_res { + warn!("close_and_unreserve: unreserveinputs failed: {e}"); + } + + match (close_res, unreserve_res) { + (Ok(_), Ok(())) => Ok(()), + (Err(e), Ok(())) => Err(e), + (Ok(_), Err(e)) => Err(e), + (Err(ce), Err(ue)) => Err(anyhow::anyhow!( + "close_and_unreserve failed: close: {ce}; unreserve: {ue}" + )), + } } async fn wait_for_forward_resolution( @@ -750,11 +869,11 @@ impl RecoveryProvider for ClnApiRpc { from_index: u64, ) -> Result<(ForwardActivity, u64)> { // Get the scid for this channel so we can match wait responses. - let scid = self.get_channel_scid(channel_id).await?; + let scid = self.rpc.get_channel_scid(channel_id).await?; let mut next_index = from_index + 1; loop { - let mut rpc = self.create_rpc().await?; + let mut rpc = self.rpc.create_rpc().await?; let wait_res = rpc .call_typed(&WaitRequest { subsystem: WaitSubsystem::FORWARDS, @@ -803,6 +922,10 @@ impl RecoveryProvider for ClnApiRpc { } } +// --------------------------------------------------------------------------- +// Datastore helpers (standalone) +// --------------------------------------------------------------------------- + #[derive(Debug)] pub enum DsError { /// No datastore entry with this exact key. diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs index f17bdb1f236a..3fc2dc0a2a78 100644 --- a/plugins/lsps-plugin/src/core/lsps2/manager.rs +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -301,11 +301,10 @@ mod tests { ShortChannelId::from(999u64 << 40 | 9u64 << 16 | 9) } - fn test_peer_id() -> cln_rpc::primitives::PublicKey { - serde_json::from_value(serde_json::json!( - "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" - )) - .unwrap() + fn test_peer_id() -> bitcoin::secp256k1::PublicKey { + "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" + .parse() + .unwrap() } fn opening_fee_params(min_fee_msat: u64) -> OpeningFeeParams { diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index 5fe47973df0f..e38b0b23480f 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -49,30 +49,42 @@ where } } -pub struct Lsps2ServiceHandler { - pub api: Arc, +pub struct Lsps2ServiceHandler { + pub datastore: Arc, + pub blockheight: Arc, + pub policy: Arc

, pub promise_secret: [u8; 32], } -impl Lsps2ServiceHandler { - pub fn new(api: Arc, promise_seret: &[u8; 32]) -> Self { +impl Lsps2ServiceHandler { + pub fn new( + datastore: Arc, + blockheight: Arc, + policy: Arc

, + promise_secret: &[u8; 32], + ) -> Self { Lsps2ServiceHandler { - api, - promise_secret: promise_seret.to_owned(), + datastore, + blockheight, + policy, + promise_secret: promise_secret.to_owned(), } } } #[async_trait] -impl Lsps2Handler - for Lsps2ServiceHandler +impl Lsps2Handler for Lsps2ServiceHandler +where + D: DatastoreProvider + 'static, + B: BlockheightProvider + 'static, + P: Lsps2PolicyProvider + 'static, { async fn handle_get_info( &self, request: Lsps2GetInfoRequest, ) -> std::result::Result { let res_data = self - .api + .policy .get_info(&Lsps2PolicyGetInfoRequest { token: request.token.clone(), }) @@ -108,7 +120,7 @@ impl // Generate a tmp scid to identify jit channel request in htlc. let blockheight = self - .api + .blockheight .get_blockheight() .await .map_err(|_| RpcError::internal_error("internal error"))?; @@ -118,7 +130,7 @@ impl let jit_scid = ShortChannelId::generate_jit(blockheight, 12); // Approximately 2 hours in the future. let ch_cap_res = self - .api + .policy .buy(&Lsps2PolicyBuyRequest { opening_fee_params: fee_params.clone(), payment_size_msat: request.payment_size_msat, @@ -131,7 +143,7 @@ impl .ok_or_else(|| RpcError::internal_error("channel capacity denied by policy"))?; let _entry = self - .api + .datastore .store_buy_request( &jit_scid, &peer_id, @@ -406,8 +418,9 @@ mod tests { } } - fn handler(api: MockApi) -> Lsps2ServiceHandler { - Lsps2ServiceHandler::new(Arc::new(api), &test_secret()) + fn handler(api: MockApi) -> Lsps2ServiceHandler { + let api = Arc::new(api); + Lsps2ServiceHandler::new(api.clone(), api.clone(), api, &test_secret()) } #[tokio::test] diff --git a/plugins/lsps-plugin/src/lib.rs b/plugins/lsps-plugin/src/lib.rs index e1f5e07f4303..72174fc05f04 100644 --- a/plugins/lsps-plugin/src/lib.rs +++ b/plugins/lsps-plugin/src/lib.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "cln")] pub mod cln_adapters; pub mod core; pub mod proto; diff --git a/plugins/lsps-plugin/src/proto/lsps0.rs b/plugins/lsps-plugin/src/proto/lsps0.rs index 96ed4ef068db..20d6f2857268 100644 --- a/plugins/lsps-plugin/src/proto/lsps0.rs +++ b/plugins/lsps-plugin/src/proto/lsps0.rs @@ -197,9 +197,79 @@ impl core::fmt::Display for Ppm { } } -/// Represents a short channel id as defined in LSPS0.scid. Matches with the -/// implementation in cln_rpc. -pub type ShortChannelId = cln_rpc::primitives::ShortChannelId; +/// Represents a short channel id as defined in LSPS0.scid. +/// Format: `{block}x{txindex}x{outnum}` encoding a u64 as +/// `(block << 40) | (txindex << 16) | outnum`. +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ShortChannelId(u64); + +impl ShortChannelId { + pub fn block(&self) -> u32 { + (self.0 >> 40) as u32 & 0xFFFFFF + } + pub fn txindex(&self) -> u32 { + (self.0 >> 16) as u32 & 0xFFFFFF + } + pub fn outnum(&self) -> u16 { + self.0 as u16 & 0xFFFF + } + pub fn to_u64(&self) -> u64 { + self.0 + } +} + +impl From for ShortChannelId { + fn from(v: u64) -> Self { + ShortChannelId(v) + } +} + +impl core::fmt::Display for ShortChannelId { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}x{}x{}", self.block(), self.txindex(), self.outnum()) + } +} + +impl core::str::FromStr for ShortChannelId { + type Err = String; + fn from_str(s: &str) -> std::result::Result { + let parts: Vec<&str> = s.split('x').collect(); + if parts.len() != 3 { + return Err(format!("Malformed short_channel_id: expected 3 parts, got {}", parts.len())); + } + let block: u64 = parts[0].parse().map_err(|e| format!("bad block: {e}"))?; + let txindex: u64 = parts[1].parse().map_err(|e| format!("bad txindex: {e}"))?; + let outnum: u64 = parts[2].parse().map_err(|e| format!("bad outnum: {e}"))?; + Ok(ShortChannelId((block << 40) | (txindex << 16) | outnum)) + } +} + +impl serde::Serialize for ShortChannelId { + fn serialize(&self, serializer: S) -> std::result::Result { + serializer.serialize_str(&self.to_string()) + } +} + +impl<'de> serde::Deserialize<'de> for ShortChannelId { + fn deserialize>(deserializer: D) -> std::result::Result { + let s: String = serde::Deserialize::deserialize(deserializer)?; + s.parse().map_err(serde::de::Error::custom) + } +} + +#[cfg(feature = "cln")] +impl From for cln_rpc::primitives::ShortChannelId { + fn from(scid: ShortChannelId) -> Self { + cln_rpc::primitives::ShortChannelId::from(scid.0) + } +} + +#[cfg(feature = "cln")] +impl From for ShortChannelId { + fn from(scid: cln_rpc::primitives::ShortChannelId) -> Self { + ShortChannelId(scid.to_u64()) + } +} /// Represents a datetime as defined in LSPS0.datetime. Uses ISO8601 in UTC /// timezone. diff --git a/plugins/lsps-plugin/src/proto/lsps2.rs b/plugins/lsps-plugin/src/proto/lsps2.rs index e5d6ad729be0..a8c3637b19c6 100644 --- a/plugins/lsps-plugin/src/proto/lsps2.rs +++ b/plugins/lsps-plugin/src/proto/lsps2.rs @@ -337,7 +337,7 @@ impl PolicyOpeningFeeParams { #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct DatastoreEntry { - pub peer_id: cln_rpc::primitives::PublicKey, + pub peer_id: bitcoin::secp256k1::PublicKey, pub opening_fee_params: OpeningFeeParams, #[serde(skip_serializing_if = "Option::is_none")] pub expected_payment_size: Option, diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 0ec8e5cf9650..901ecc3529ed 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -3,7 +3,9 @@ use bitcoin::hashes::Hash; use chrono::Utc; use cln_lsps::{ cln_adapters::{ - hooks::service_custommsg_hook, rpc::ClnApiRpc, sender::ClnSender, state::ServiceState, + hooks::service_custommsg_hook, + rpc::{ClnActionExecutor, ClnBlockheight, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient}, + sender::ClnSender, state::ServiceState, types::HtlcAcceptedRequest, }, core::{ @@ -51,19 +53,25 @@ struct State { lsps_service: Arc, sender: ClnSender, lsps2_enabled: bool, - api: Arc, - session_manager: Arc>, + datastore: Arc, + recovery: Arc, + session_manager: Arc>, } impl State { pub fn new(rpc_path: PathBuf, promise_secret: &[u8; 32], collect_timeout_secs: u64) -> Self { - let api = Arc::new(ClnApiRpc::new(rpc_path.clone())); + let rpc = ClnRpcClient::new(rpc_path.clone()); let sender = ClnSender::new(rpc_path); - let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(api.clone(), promise_secret)); + let datastore = Arc::new(ClnDatastore::new(rpc.clone())); + let blockheight = Arc::new(ClnBlockheight::new(rpc.clone())); + let policy = Arc::new(ClnPolicyProvider::new(rpc.clone())); + let executor = Arc::new(ClnActionExecutor::new(rpc.clone())); + let recovery = Arc::new(ClnRecoveryProvider::new(rpc)); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(datastore.clone(), blockheight, policy, promise_secret)); let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); let session_manager = Arc::new(SessionManager::new( - api.clone(), - api.clone(), + datastore.clone(), + executor, SessionConfig { collect_timeout_secs, ..SessionConfig::default() @@ -74,7 +82,8 @@ impl State { lsps_service, sender, lsps2_enabled: true, - api, + datastore, + recovery, session_manager, } } @@ -151,7 +160,7 @@ async fn main() -> Result<(), anyhow::Error> { let state = State::new(rpc_path, &secret, collect_timeout_secs); // Recover in-flight sessions before processing replayed HTLCs - let recovery: Arc = state.api.clone(); + let recovery: Arc = state.recovery.clone(); if let Err(e) = state.session_manager.recover(recovery).await { warn!("session recovery failed: {e}"); } @@ -198,8 +207,8 @@ async fn handle_htlc_inner( let req: HtlcAcceptedRequest = serde_json::from_value(v)?; - let short_channel_id = match req.onion.short_channel_id { - Some(scid) => scid, + let short_channel_id: ShortChannelId = match req.onion.short_channel_id { + Some(scid) => scid.into(), None => { trace!("We are the destination of the HTLC, continue."); return Ok(json_continue()); @@ -207,7 +216,7 @@ async fn handle_htlc_inner( }; // Decide path: look up buy request to check for MPP. - let ds_rec = match p.state().api.get_buy_request(&short_channel_id).await { + let ds_rec = match p.state().datastore.get_buy_request(&short_channel_id).await { Ok(rec) => rec, Err(_) => { trace!("SCID not ours, continue."); @@ -218,7 +227,7 @@ async fn handle_htlc_inner( if Utc::now() >= ds_rec.opening_fee_params.valid_until { let _ = p .state() - .api + .datastore .finalize_session(&short_channel_id, SessionOutcome::Timeout) .await; return Ok(json_fail(UNKNOWN_NEXT_PEER)); From df7c1329ea39c86730beed91acd4fd608f9a600f Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 00:13:01 +0100 Subject: [PATCH 09/11] plugins(lsps2): simplify actor and session manager Merge BlockheightProvider into Lsps2PolicyProvider, extract check_cltv_timeout helper in the session FSM, flatten recovery branching in SessionManager, simplify the actor loop with convert_input and tokio::select!, and remove the unused CollectTimeout ActorInput variant. --- plugins/lsps-plugin/src/cln_adapters/mod.rs | 2 +- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 30 +- plugins/lsps-plugin/src/core/lsps2/actor.rs | 256 ++++++++---------- plugins/lsps-plugin/src/core/lsps2/manager.rs | 146 +++++----- .../lsps-plugin/src/core/lsps2/provider.rs | 9 +- plugins/lsps-plugin/src/core/lsps2/service.rs | 41 ++- plugins/lsps-plugin/src/core/lsps2/session.rs | 78 +++--- plugins/lsps-plugin/src/service.rs | 5 +- 8 files changed, 252 insertions(+), 315 deletions(-) diff --git a/plugins/lsps-plugin/src/cln_adapters/mod.rs b/plugins/lsps-plugin/src/cln_adapters/mod.rs index 821b16f474b0..cb162acdcf80 100644 --- a/plugins/lsps-plugin/src/cln_adapters/mod.rs +++ b/plugins/lsps-plugin/src/cln_adapters/mod.rs @@ -5,6 +5,6 @@ pub mod state; pub mod types; pub use rpc::{ - ClnActionExecutor, ClnBlockheight, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, + ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient, }; diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 5b4b814dc421..87b98d142831 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -2,7 +2,7 @@ use crate::{ core::lsps2::{ actor::ActionExecutor, provider::{ - Blockheight, BlockheightProvider, ChannelRecoveryInfo, DatastoreProvider, + ChannelRecoveryInfo, DatastoreProvider, ForwardActivity, Lsps2PolicyProvider, RecoveryProvider, }, }, @@ -661,23 +661,23 @@ impl DatastoreProvider for ClnDatastore { } // --------------------------------------------------------------------------- -// ClnBlockheight — implements BlockheightProvider +// ClnPolicyProvider — implements Lsps2PolicyProvider // --------------------------------------------------------------------------- #[derive(Clone)] -pub struct ClnBlockheight { +pub struct ClnPolicyProvider { rpc: ClnRpcClient, } -impl ClnBlockheight { +impl ClnPolicyProvider { pub fn new(rpc: ClnRpcClient) -> Self { Self { rpc } } } #[async_trait] -impl BlockheightProvider for ClnBlockheight { - async fn get_blockheight(&self) -> Result { +impl Lsps2PolicyProvider for ClnPolicyProvider { + async fn get_blockheight(&self) -> Result { let mut rpc = self.rpc.create_rpc().await?; let info = rpc .call_typed(&GetinfoRequest {}) @@ -686,25 +686,7 @@ impl BlockheightProvider for ClnBlockheight { .with_context(|| "calling getinfo")?; Ok(info.blockheight) } -} - -// --------------------------------------------------------------------------- -// ClnPolicyProvider — implements Lsps2PolicyProvider -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct ClnPolicyProvider { - rpc: ClnRpcClient, -} - -impl ClnPolicyProvider { - pub fn new(rpc: ClnRpcClient) -> Self { - Self { rpc } - } -} -#[async_trait] -impl Lsps2PolicyProvider for ClnPolicyProvider { async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs index 13a39d458c5b..8758d3ac045e 100644 --- a/plugins/lsps-plugin/src/core/lsps2/actor.rs +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -15,10 +15,7 @@ use bitcoin::hashes::sha256::Hash as PaymentHash; use bitcoin::hashes::Hash; use log::{debug, warn}; use std::{collections::HashMap, sync::Arc, time::Duration}; -use tokio::{ - sync::{mpsc, oneshot}, - task::JoinHandle, -}; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum HtlcResponse { @@ -38,7 +35,6 @@ enum ActorInput { part: PaymentPart, reply_tx: oneshot::Sender, }, - CollectTimeout, ChannelReady { channel_id: String, funding_psbt: String, @@ -129,8 +125,8 @@ pub struct SessionActor { entry: DatastoreEntry, inbox: mpsc::Receiver, pending_htlcs: HashMap>, - collect_timeout_handle: Option>, - channel_poll_handle: Option>, + collect_fired: bool, + channel_poll_handle: Option>, self_send: mpsc::Sender, executor: A, peer_id: String, @@ -159,7 +155,7 @@ impl Option { + match input { + ActorInput::AddPart { part, reply_tx } => { + let htlc_id = part.htlc_id; + self.pending_htlcs.insert(htlc_id, reply_tx); + Some(SessionInput::AddPart { part }) + } + ActorInput::ChannelReady { + channel_id, + funding_psbt, + } => { + self.entry.channel_id = Some(channel_id.clone()); + self.entry.funding_psbt = Some(funding_psbt.clone()); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on ChannelReady: {e}"); + } + Some(SessionInput::ChannelReady { + channel_id, + funding_psbt, + }) + } + ActorInput::FundingFailed => Some(SessionInput::FundingFailed), + ActorInput::PaymentSettled { + preimage, + updated_index, + } => { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + } + if let Some(ref pre) = preimage { + self.entry.preimage = Some(pre.clone()); + } + if updated_index.is_some() || preimage.is_some() { + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentSettled: {e}"); + } + } + Some(SessionInput::PaymentSettled) + } + ActorInput::PaymentFailed { updated_index } => { + if let Some(index) = updated_index { + self.entry.forwards_updated_index = Some(index); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on PaymentFailed: {e}"); + } + } + Some(SessionInput::PaymentFailed) + } + ActorInput::FundingBroadcasted { txid } => { + self.entry.funding_txid = Some(txid); + if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { + warn!("save_session failed on FundingBroadcasted: {e}"); + } + Some(SessionInput::FundingBroadcasted) + } + ActorInput::NewBlock { height } => Some(SessionInput::NewBlock { height }), + ActorInput::ChannelClosed { channel_id } => { + Some(SessionInput::ChannelClosed { channel_id }) + } + } } - fn cancel_collect_timeout(&mut self) { - if let Some(handle) = self.collect_timeout_handle.take() { - handle.abort(); + /// Apply a session input to the FSM and execute resulting actions. + /// Returns `true` if the session reached a terminal state. + fn apply_and_execute(&mut self, input: SessionInput) -> bool { + match self.session.apply(input) { + Ok(result) => { + self.dispatch_events(result.events); + for action in result.actions { + self.execute_action(action); + } + self.session.is_terminal() + } + Err(e) => { + warn!("session FSM error: {e}"); + if self.session.is_terminal() { + self.release_pending_htlcs(); + true + } else { + false + } + } } } @@ -278,93 +344,31 @@ impl { - let htlc_id = part.htlc_id; - self.pending_htlcs.insert(htlc_id, reply_tx); - SessionInput::AddPart { part } - } - ActorInput::CollectTimeout => SessionInput::CollectTimeout, - ActorInput::ChannelReady { - channel_id, - funding_psbt, - } => { - self.entry.channel_id = Some(channel_id.clone()); - self.entry.funding_psbt = Some(funding_psbt.clone()); - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on ChannelReady: {e}"); - } - SessionInput::ChannelReady { - channel_id, - funding_psbt, - } - } - ActorInput::FundingFailed => SessionInput::FundingFailed, - ActorInput::PaymentSettled { - preimage, - updated_index, - } => { - if let Some(index) = updated_index { - self.entry.forwards_updated_index = Some(index); - } - if let Some(ref pre) = preimage { - self.entry.preimage = Some(pre.clone()); - } - if updated_index.is_some() || preimage.is_some() { - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on PaymentSettled: {e}"); - } - } - SessionInput::PaymentSettled - } - ActorInput::PaymentFailed { updated_index } => { - if let Some(index) = updated_index { - self.entry.forwards_updated_index = Some(index); - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on PaymentFailed: {e}"); - } - } - SessionInput::PaymentFailed - } - ActorInput::FundingBroadcasted { txid } => { - self.entry.funding_txid = Some(txid); - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on FundingBroadcasted: {e}"); - } - SessionInput::FundingBroadcasted - } - ActorInput::NewBlock { height } => SessionInput::NewBlock { height }, - ActorInput::ChannelClosed { channel_id } => { - SessionInput::ChannelClosed { channel_id } - } - }; - - match self.session.apply(input) { - Ok(result) => { - self.dispatch_events(result.events); + let collect_deadline = tokio::time::sleep( + Duration::from_secs(self.collect_timeout_secs), + ); + tokio::pin!(collect_deadline); - for action in result.actions { - self.execute_action(action); - } - - if self.session.is_terminal() { + loop { + tokio::select! { + input = self.inbox.recv() => { + let Some(input) = input else { break }; + let Some(session_input) = self.convert_input(input).await else { + continue; + }; + if self.apply_and_execute(session_input) { break; } } - Err(e) => { - warn!("session FSM error: {e}"); - if self.session.is_terminal() { - self.release_pending_htlcs(); + _ = &mut collect_deadline, if !self.collect_fired => { + self.collect_fired = true; + if self.apply_and_execute(SessionInput::CollectTimeout) { break; } } } } - // We exited the loop, just continue all held HTLCs and let the handler - // decide. self.release_pending_htlcs(); Self::finalize(&self.session, &self.datastore, self.scid).await; } @@ -462,59 +466,21 @@ impl { - let session_input = match actor_input { - ActorInput::PaymentSettled { - preimage, - updated_index, - } => { - if let Some(index) = updated_index { - self.entry.forwards_updated_index = Some(index); - } - if let Some(ref pre) = preimage { - self.entry.preimage = Some(pre.clone()); - } - if updated_index.is_some() || preimage.is_some() { - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on PaymentSettled: {e}"); - } - } - SessionInput::PaymentSettled - } - ActorInput::PaymentFailed { updated_index } => { - if let Some(index) = updated_index { - self.entry.forwards_updated_index = Some(index); - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on PaymentFailed: {e}"); - } - } - SessionInput::PaymentFailed - } - ActorInput::FundingBroadcasted { txid } => { - self.entry.funding_txid = Some(txid); - if let Err(e) = self.datastore.save_session(&self.scid, &self.entry).await { - warn!("save_session failed on FundingBroadcasted: {e}"); - } - SessionInput::FundingBroadcasted + // Only process the three shared persistence arms + let session_input = match &actor_input { + ActorInput::PaymentSettled { .. } + | ActorInput::PaymentFailed { .. } + | ActorInput::FundingBroadcasted { .. } => { + self.convert_input(actor_input).await } _ => continue, }; - match self.session.apply(session_input) { - Ok(result) => { - self.dispatch_events(result.events); - for action in result.actions { - self.execute_action(action); - } - } - Err(e) => { - warn!("FSM error in recovered session: {e}"); + if let Some(input) = session_input { + if self.apply_and_execute(input) { break; } } - - if self.session.is_terminal() { - break; - } } None => break, } @@ -540,9 +506,9 @@ impl { - // First time forwarding HTLCs, we cancel the collect timeout - // and start polling the channel for closure: - self.cancel_collect_timeout(); + // First time forwarding HTLCs, we mark the collect timeout as + // fired and start polling the channel for closure: + self.collect_fired = true; self.start_channel_poll(channel_id.clone()); for part in &parts { if let Some(reply_tx) = self.pending_htlcs.remove(&part.htlc_id) { diff --git a/plugins/lsps-plugin/src/core/lsps2/manager.rs b/plugins/lsps-plugin/src/core/lsps2/manager.rs index 3fc2dc0a2a78..4389945e2f49 100644 --- a/plugins/lsps-plugin/src/core/lsps2/manager.rs +++ b/plugins/lsps-plugin/src/core/lsps2/manager.rs @@ -4,7 +4,7 @@ use super::session::{PaymentPart, Session}; use crate::core::lsps2::actor::SessionActor; use crate::core::lsps2::event_sink::EventSink; use crate::proto::lsps0::ShortChannelId; -use crate::proto::lsps2::SessionOutcome; +use crate::proto::lsps2::{DatastoreEntry, SessionOutcome}; pub use bitcoin::hashes::sha256::Hash as PaymentHash; use chrono::Utc; use log::{debug, warn}; @@ -61,80 +61,92 @@ impl let entries = self.datastore.list_active_sessions().await?; for (scid, entry) in entries { - match (&entry.channel_id, &entry.funding_psbt) { - (None, _) => { - if entry.opening_fee_params.valid_until < Utc::now() { - self.datastore - .finalize_session(&scid, SessionOutcome::Timeout) - .await?; - } - } + if let Some(handle) = self.recover_session(scid, entry, &recovery).await? { + self.recovery_handles.lock().await.push(handle); + } + } - (Some(channel_id), Some(funding_psbt)) => { - let info = recovery.get_channel_recovery_info(channel_id).await?; - - if !info.exists { - self.datastore - .finalize_session(&scid, SessionOutcome::Abandoned) - .await?; - continue; - } - - let activity = recovery.get_forward_activity(channel_id).await?; - - match activity { - ForwardActivity::NoForwards => { - recovery - .close_and_unreserve(channel_id, funding_psbt) - .await?; - let mut entry = entry; - entry.channel_id = None; - entry.funding_psbt = None; - entry.funding_txid = None; - self.datastore.save_session(&scid, &entry).await?; - } - ForwardActivity::AllFailed => { - self.datastore - .finalize_session(&scid, SessionOutcome::Abandoned) - .await?; - } - ForwardActivity::Offered | ForwardActivity::Settled => { - let (session, initial_actions) = Session::recover( - channel_id.clone(), - funding_psbt.clone(), - entry.preimage.clone(), - entry.opening_fee_params.clone(), - ); - - let handle = - SessionActor::spawn_recovered_session_actor( - session, - entry, - initial_actions, - channel_id, - self.executor.clone(), - scid, - self.datastore.clone(), - recovery.clone(), - forwards_updated_index, - self.event_sink.clone(), - ); - - self.recovery_handles.lock().await.push(handle); - } - } - } + Ok(()) + } - _ => { - warn!("inconsistent datastore entry for scid={scid}, finalizing as Failed"); + async fn recover_session( + &self, + scid: ShortChannelId, + entry: DatastoreEntry, + recovery: &Arc, + ) -> anyhow::Result> { + let (channel_id, funding_psbt) = match (&entry.channel_id, &entry.funding_psbt) { + (None, _) => { + if entry.opening_fee_params.valid_until < Utc::now() { self.datastore - .finalize_session(&scid, SessionOutcome::Failed) + .finalize_session(&scid, SessionOutcome::Timeout) .await?; } + return Ok(None); } + (Some(cid), Some(psbt)) => (cid.clone(), psbt.clone()), + _ => { + warn!("inconsistent datastore entry for scid={scid}, finalizing as Failed"); + self.datastore + .finalize_session(&scid, SessionOutcome::Failed) + .await?; + return Ok(None); + } + }; + + let info = recovery.get_channel_recovery_info(&channel_id).await?; + if !info.exists { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + return Ok(None); } - Ok(()) + let activity = recovery.get_forward_activity(&channel_id).await?; + + match activity { + ForwardActivity::NoForwards => { + recovery + .close_and_unreserve(&channel_id, &funding_psbt) + .await?; + let mut entry = entry; + entry.channel_id = None; + entry.funding_psbt = None; + entry.funding_txid = None; + self.datastore.save_session(&scid, &entry).await?; + Ok(None) + } + ForwardActivity::AllFailed => { + self.datastore + .finalize_session(&scid, SessionOutcome::Abandoned) + .await?; + Ok(None) + } + ForwardActivity::Offered | ForwardActivity::Settled => { + let forwards_updated_index = entry.forwards_updated_index; + let (session, initial_actions) = Session::recover( + channel_id.clone(), + funding_psbt.clone(), + entry.preimage.clone(), + entry.opening_fee_params.clone(), + ); + + let handle = SessionActor::spawn_recovered_session_actor( + session, + entry, + initial_actions, + channel_id, + self.executor.clone(), + scid, + self.datastore.clone(), + recovery.clone(), + forwards_updated_index, + self.event_sink.clone(), + ); + + Ok(Some(handle)) + } + } } pub async fn on_part( diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index fef5c44da8bc..04f9dfaa3f0a 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -10,13 +10,6 @@ use crate::proto::{ }, }; -pub type Blockheight = u32; - -#[async_trait] -pub trait BlockheightProvider: Send + Sync { - async fn get_blockheight(&self) -> Result; -} - #[async_trait] pub trait DatastoreProvider: Send + Sync { async fn store_buy_request( @@ -83,6 +76,8 @@ pub trait RecoveryProvider: Send + Sync { #[async_trait] pub trait Lsps2PolicyProvider: Send + Sync { + async fn get_blockheight(&self) -> Result; + async fn get_info( &self, request: &Lsps2PolicyGetInfoRequest, diff --git a/plugins/lsps-plugin/src/core/lsps2/service.rs b/plugins/lsps-plugin/src/core/lsps2/service.rs index e38b0b23480f..8b980a35b1a7 100644 --- a/plugins/lsps-plugin/src/core/lsps2/service.rs +++ b/plugins/lsps-plugin/src/core/lsps2/service.rs @@ -1,6 +1,6 @@ use crate::{ core::{ - lsps2::provider::{BlockheightProvider, DatastoreProvider, Lsps2PolicyProvider}, + lsps2::provider::{DatastoreProvider, Lsps2PolicyProvider}, router::JsonRpcRouterBuilder, server::LspsProtocol, }, @@ -49,23 +49,20 @@ where } } -pub struct Lsps2ServiceHandler { +pub struct Lsps2ServiceHandler { pub datastore: Arc, - pub blockheight: Arc, pub policy: Arc

, pub promise_secret: [u8; 32], } -impl Lsps2ServiceHandler { +impl Lsps2ServiceHandler { pub fn new( datastore: Arc, - blockheight: Arc, policy: Arc

, promise_secret: &[u8; 32], ) -> Self { Lsps2ServiceHandler { datastore, - blockheight, policy, promise_secret: promise_secret.to_owned(), } @@ -73,10 +70,9 @@ impl Lsps2ServiceHandler { } #[async_trait] -impl Lsps2Handler for Lsps2ServiceHandler +impl Lsps2Handler for Lsps2ServiceHandler where D: DatastoreProvider + 'static, - B: BlockheightProvider + 'static, P: Lsps2PolicyProvider + 'static, { async fn handle_get_info( @@ -120,7 +116,7 @@ where // Generate a tmp scid to identify jit channel request in htlc. let blockheight = self - .blockheight + .policy .get_blockheight() .await .map_err(|_| RpcError::internal_error("internal error"))?; @@ -301,6 +297,16 @@ mod tests { #[async_trait] impl Lsps2PolicyProvider for MockApi { + async fn get_blockheight(&self) -> AnyResult { + if *self.blockheight_error.lock().unwrap() { + return Err(anyhow!("blockheight error")); + } + self.blockheight + .lock() + .unwrap() + .ok_or_else(|| anyhow!("no blockheight set")) + } + async fn get_info( &self, _request: &Lsps2PolicyGetInfoRequest, @@ -333,19 +339,6 @@ mod tests { } } - #[async_trait] - impl BlockheightProvider for MockApi { - async fn get_blockheight(&self) -> AnyResult { - if *self.blockheight_error.lock().unwrap() { - return Err(anyhow!("blockheight error")); - } - self.blockheight - .lock() - .unwrap() - .ok_or_else(|| anyhow!("no blockheight set")) - } - } - #[async_trait] impl DatastoreProvider for MockApi { async fn store_buy_request( @@ -418,9 +411,9 @@ mod tests { } } - fn handler(api: MockApi) -> Lsps2ServiceHandler { + fn handler(api: MockApi) -> Lsps2ServiceHandler { let api = Arc::new(api); - Lsps2ServiceHandler::new(api.clone(), api.clone(), api, &test_secret()) + Lsps2ServiceHandler::new(api.clone(), api, &test_secret()) } #[tokio::test] diff --git a/plugins/lsps-plugin/src/core/lsps2/session.rs b/plugins/lsps-plugin/src/core/lsps2/session.rs index 38952d55e88d..ea191a9f98d7 100644 --- a/plugins/lsps-plugin/src/core/lsps2/session.rs +++ b/plugins/lsps-plugin/src/core/lsps2/session.rs @@ -311,6 +311,35 @@ impl Session { } } + fn check_cltv_timeout( + &mut self, + parts: &[PaymentPart], + height: u32, + ) -> Option { + let min = cltv_min(parts)?; + if height > min { + self.state = SessionState::Failed; + Some(ApplyResult { + actions: vec![ + SessionAction::FailHtlcs { + failure_code: TEMPORARY_CHANNEL_FAILURE, + }, + SessionAction::Disconnect, + SessionAction::FailSession, + ], + events: vec![ + SessionEvent::UnsafeHtlcTimeout { + height, + cltv_min: min, + }, + SessionEvent::SessionFailed, + ], + }) + } else { + None + } + } + pub fn apply(&mut self, input: SessionInput) -> Result { match (&mut self.state, input) { // @@ -440,28 +469,8 @@ impl Session { }) } (SessionState::Collecting { parts }, SessionInput::NewBlock { height }) => { - if let Some(min) = cltv_min(parts) { - if height > min { - self.state = SessionState::Failed; - return Ok(ApplyResult { - actions: vec![ - SessionAction::FailHtlcs { - failure_code: TEMPORARY_CHANNEL_FAILURE, - }, - SessionAction::FailSession, - ], - events: vec![ - SessionEvent::UnsafeHtlcTimeout { - height, - cltv_min: min, - }, - SessionEvent::SessionFailed, - ], - }); - } - } - // No parts or height <= cltv_min: stay collecting. - Ok(ApplyResult::default()) + let parts = parts.clone(); + Ok(self.check_cltv_timeout(&parts, height).unwrap_or_default()) } ( SessionState::Collecting { .. }, @@ -608,28 +617,8 @@ impl Session { SessionState::AwaitingChannelReady { parts, .. }, SessionInput::NewBlock { height }, ) => { - if let Some(min) = cltv_min(parts) { - if height > min { - self.state = SessionState::Failed; - return Ok(ApplyResult { - actions: vec![ - SessionAction::FailHtlcs { - failure_code: TEMPORARY_CHANNEL_FAILURE, - }, - SessionAction::Disconnect, - SessionAction::FailSession, - ], - events: vec![ - SessionEvent::UnsafeHtlcTimeout { - height, - cltv_min: min, - }, - SessionEvent::SessionFailed, - ], - }); - } - } - Ok(ApplyResult::default()) + let parts = parts.clone(); + Ok(self.check_cltv_timeout(&parts, height).unwrap_or_default()) } ( SessionState::AwaitingChannelReady { .. }, @@ -1819,6 +1808,7 @@ mod tests { SessionAction::FailHtlcs { failure_code: TEMPORARY_CHANNEL_FAILURE, }, + SessionAction::Disconnect, SessionAction::FailSession, ] ); diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index 901ecc3529ed..d11b17868992 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -4,7 +4,7 @@ use chrono::Utc; use cln_lsps::{ cln_adapters::{ hooks::service_custommsg_hook, - rpc::{ClnActionExecutor, ClnBlockheight, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient}, + rpc::{ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient}, sender::ClnSender, state::ServiceState, types::HtlcAcceptedRequest, }, @@ -63,11 +63,10 @@ impl State { let rpc = ClnRpcClient::new(rpc_path.clone()); let sender = ClnSender::new(rpc_path); let datastore = Arc::new(ClnDatastore::new(rpc.clone())); - let blockheight = Arc::new(ClnBlockheight::new(rpc.clone())); let policy = Arc::new(ClnPolicyProvider::new(rpc.clone())); let executor = Arc::new(ClnActionExecutor::new(rpc.clone())); let recovery = Arc::new(ClnRecoveryProvider::new(rpc)); - let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(datastore.clone(), blockheight, policy, promise_secret)); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(datastore.clone(), policy, promise_secret)); let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); let session_manager = Arc::new(SessionManager::new( datastore.clone(), From be5c06bdd21b4c62585b3dab6f0b4c88c12291c2 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 13:24:24 +0100 Subject: [PATCH 10/11] plugins(lsps2): make collect timeout dev config We actually only use this in tests Signed-off-by: Peter Neuroth --- plugins/lsps-plugin/src/service.rs | 31 +++++++++++++++--------------- tests/test_cln_lsps.py | 2 +- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/plugins/lsps-plugin/src/service.rs b/plugins/lsps-plugin/src/service.rs index d11b17868992..58a22c4637a1 100644 --- a/plugins/lsps-plugin/src/service.rs +++ b/plugins/lsps-plugin/src/service.rs @@ -4,8 +4,11 @@ use chrono::Utc; use cln_lsps::{ cln_adapters::{ hooks::service_custommsg_hook, - rpc::{ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient}, - sender::ClnSender, state::ServiceState, + rpc::{ + ClnActionExecutor, ClnDatastore, ClnPolicyProvider, ClnRecoveryProvider, ClnRpcClient, + }, + sender::ClnSender, + state::ServiceState, types::HtlcAcceptedRequest, }, core::{ @@ -14,8 +17,8 @@ use cln_lsps::{ event_sink::NoopEventSink, manager::{PaymentHash, SessionConfig, SessionManager}, provider::{DatastoreProvider, RecoveryProvider}, - session::PaymentPart, service::Lsps2ServiceHandler, + session::PaymentPart, }, server::LspsService, tlv::{TlvStream, TLV_FORWARD_AMT}, @@ -43,7 +46,7 @@ pub const OPTION_PROMISE_SECRET: options::StringConfigOption = pub const OPTION_COLLECT_TIMEOUT: options::DefaultIntegerConfigOption = options::ConfigOption::new_i64_with_default( - "experimental-lsps2-collect-timeout", + "dev-lsps2-collect-timeout", 90, "Timeout in seconds for collecting MPP parts (default: 90)", ); @@ -66,7 +69,11 @@ impl State { let policy = Arc::new(ClnPolicyProvider::new(rpc.clone())); let executor = Arc::new(ClnActionExecutor::new(rpc.clone())); let recovery = Arc::new(ClnRecoveryProvider::new(rpc)); - let lsps2_handler = Arc::new(Lsps2ServiceHandler::new(datastore.clone(), policy, promise_secret)); + let lsps2_handler = Arc::new(Lsps2ServiceHandler::new( + datastore.clone(), + policy, + promise_secret, + )); let lsps_service = Arc::new(LspsService::builder().with_protocol(lsps2_handler).build()); let session_manager = Arc::new(SessionManager::new( datastore.clone(), @@ -240,8 +247,7 @@ async fn handle_session_htlc( req: &HtlcAcceptedRequest, scid: ShortChannelId, ) -> Result { - let payment_hash = - PaymentHash::from_byte_array(req.htlc.payment_hash.as_slice().try_into()?); + let payment_hash = PaymentHash::from_byte_array(req.htlc.payment_hash.as_slice().try_into()?); let part = PaymentPart { htlc_id: req.htlc.id, amount_msat: Msat::from_msat(req.htlc.amount_msat.msat()), @@ -297,10 +303,7 @@ fn session_response_to_json( } } -async fn on_forward_event( - p: Plugin, - v: serde_json::Value, -) -> Result<(), anyhow::Error> { +async fn on_forward_event(p: Plugin, v: serde_json::Value) -> Result<(), anyhow::Error> { let event = match v.get("forward_event") { Some(e) => e, None => return Ok(()), @@ -356,10 +359,7 @@ async fn on_forward_event( Ok(()) } -async fn on_block_added( - p: Plugin, - v: serde_json::Value, -) -> Result<(), anyhow::Error> { +async fn on_block_added(p: Plugin, v: serde_json::Value) -> Result<(), anyhow::Error> { let height = match v .get("block_added") .and_then(|b| b.get("height")) @@ -396,4 +396,3 @@ fn json_fail(failure_code: &str) -> serde_json::Value { "failure_message": failure_code }) } - diff --git a/tests/test_cln_lsps.py b/tests/test_cln_lsps.py index c220edaeb32b..f87294c1f520 100644 --- a/tests/test_cln_lsps.py +++ b/tests/test_cln_lsps.py @@ -13,7 +13,7 @@ LSP_OPTS = { "experimental-lsps2-service": None, "experimental-lsps2-promise-secret": "0" * 64, - "experimental-lsps2-collect-timeout": 5, + "dev-lsps2-collect-timeout": 5, "plugin": POLICY_PLUGIN, "fee-base": 0, "fee-per-satoshi": 0, From 73b18ed3427b48ddff40d47179b38eae125af2d7 Mon Sep 17 00:00:00 2001 From: Peter Neuroth Date: Thu, 19 Mar 2026 19:20:09 +0100 Subject: [PATCH 11/11] plugins(lsps2): route recovered sessions through forward_event path After restart, recovered session actors were stored in a separate recovery_handles Vec, unreachable by the forward_event notification path that routes via the sessions HashMap. This caused intermittent CI failures where on_payment_settled could not find the session and the internal forward-monitoring loop failed to detect settlement. Register recovered sessions in the sessions HashMap keyed by payment_hash so forward_event notifications reach them directly. For already-settled forwards, recover into Broadcasting state so the actor self-drives to completion without needing forward_event re-delivery. Remove the now-redundant internal polling loop (get_forward_activity + wait_for_forward_resolution). --- plugins/lsps-plugin/src/cln_adapters/rpc.rs | 60 +----- plugins/lsps-plugin/src/core/lsps2/actor.rs | 92 +------- plugins/lsps-plugin/src/core/lsps2/manager.rs | 198 ++++++++++++++++-- .../lsps-plugin/src/core/lsps2/provider.rs | 9 - 4 files changed, 187 insertions(+), 172 deletions(-) diff --git a/plugins/lsps-plugin/src/cln_adapters/rpc.rs b/plugins/lsps-plugin/src/cln_adapters/rpc.rs index 87b98d142831..aa841779f15c 100644 --- a/plugins/lsps-plugin/src/cln_adapters/rpc.rs +++ b/plugins/lsps-plugin/src/cln_adapters/rpc.rs @@ -26,9 +26,8 @@ use cln_rpc::{ FundchannelCompleteRequest, FundchannelStartRequest, FundpsbtRequest, GetinfoRequest, ListdatastoreRequest, ListforwardsIndex, ListforwardsRequest, ListpeerchannelsRequest, SendpsbtRequest, SignpsbtRequest, UnreserveinputsRequest, - WaitIndexname, WaitRequest, WaitSubsystem, }, - responses::{ListdatastoreResponse, ListforwardsForwardsStatus, WaitForwardsStatus}, + responses::{ListdatastoreResponse, ListforwardsForwardsStatus}, }, primitives::{Amount, AmountOrAll, ChannelState, Feerate, Sha256}, ClnRpc, @@ -845,63 +844,6 @@ impl RecoveryProvider for ClnRecoveryProvider { } } - async fn wait_for_forward_resolution( - &self, - channel_id: &str, - from_index: u64, - ) -> Result<(ForwardActivity, u64)> { - // Get the scid for this channel so we can match wait responses. - let scid = self.rpc.get_channel_scid(channel_id).await?; - - let mut next_index = from_index + 1; - loop { - let mut rpc = self.rpc.create_rpc().await?; - let wait_res = rpc - .call_typed(&WaitRequest { - subsystem: WaitSubsystem::FORWARDS, - indexname: WaitIndexname::UPDATED, - nextvalue: next_index, - }) - .await - .with_context(|| { - format!("calling wait for channel_id={channel_id} at index={next_index}") - })?; - - let new_index = wait_res.updated.unwrap_or(next_index); - - // Check if this update is for our channel. - let is_our_channel = match (&scid, &wait_res.forwards) { - (Some(our_scid), Some(fwd)) => fwd - .out_channel - .as_ref() - .map(|c| c == our_scid) - .unwrap_or(false), - _ => false, - }; - - if is_our_channel { - if let Some(fwd) = &wait_res.forwards { - match fwd.status { - Some(WaitForwardsStatus::SETTLED) => { - return Ok((ForwardActivity::Settled, new_index)); - } - Some(WaitForwardsStatus::OFFERED) => { - return Ok((ForwardActivity::Offered, new_index)); - } - Some(WaitForwardsStatus::FAILED) - | Some(WaitForwardsStatus::LOCAL_FAILED) => { - // Check full history to decide AllFailed vs Active. - let activity = self.get_forward_activity(channel_id).await?; - return Ok((activity, new_index)); - } - None => {} - } - } - } - - next_index = new_index + 1; - } - } } // --------------------------------------------------------------------------- diff --git a/plugins/lsps-plugin/src/core/lsps2/actor.rs b/plugins/lsps-plugin/src/core/lsps2/actor.rs index 8758d3ac045e..b45a3cc7ea98 100644 --- a/plugins/lsps-plugin/src/core/lsps2/actor.rs +++ b/plugins/lsps-plugin/src/core/lsps2/actor.rs @@ -1,7 +1,7 @@ use crate::{ core::lsps2::{ event_sink::{EventSink, SessionEventEnvelope}, - provider::{DatastoreProvider, ForwardActivity, RecoveryProvider}, + provider::DatastoreProvider, session::{PaymentPart, Session, SessionAction, SessionEvent, SessionInput}, }, proto::{ @@ -173,12 +173,9 @@ impl, - channel_id: String, executor: A, scid: ShortChannelId, datastore: D, - recovery: Arc, - forwards_updated_index: Option, event_sink: Arc, ) -> ActorInboxHandle { let (tx, inbox) = mpsc::channel(128); @@ -200,12 +197,7 @@ impl, - channel_id: String, - recovery: Arc, - forwards_updated_index: Option, ) { // Execute initial actions (e.g., BroadcastFundingTx for Broadcasting state) for action in initial_actions { @@ -390,83 +379,11 @@ impl { - let _ = self_tx - .send(ActorInput::PaymentSettled { - preimage: None, - updated_index: None, - }) - .await; - return; - } - Ok(ForwardActivity::AllFailed) => { - let _ = self_tx - .send(ActorInput::PaymentFailed { updated_index: None }) - .await; - return; - } - Ok(ForwardActivity::Offered) - | Ok(ForwardActivity::NoForwards) - | Err(_) => { - // Fall through to wait loop - } - } - - // Poll using wait subsystem - let mut current_index = from_index; - loop { - match recovery - .wait_for_forward_resolution(&channel_id, current_index) - .await - { - Ok((ForwardActivity::Settled, new_index)) => { - let _ = self_tx - .send(ActorInput::PaymentSettled { - preimage: None, - updated_index: Some(new_index), - }) - .await; - return; - } - Ok((ForwardActivity::AllFailed, new_index)) => { - let _ = self_tx - .send(ActorInput::PaymentFailed { - updated_index: Some(new_index), - }) - .await; - return; - } - Ok((ForwardActivity::Offered, new_index)) - | Ok((ForwardActivity::NoForwards, new_index)) => { - current_index = new_index; - continue; - } - Err(e) => { - warn!("forward monitoring error for scid={scid}: {e}"); - tokio::time::sleep(Duration::from_secs(5)).await; - continue; - } - } - } - }) - }; - - // Main loop: process inbox events + // Main loop: process inbox events from forward_event notifications loop { match self.inbox.recv().await { Some(actor_input) => { - // Only process the three shared persistence arms + // Only process settlement/failure/broadcast events let session_input = match &actor_input { ActorInput::PaymentSettled { .. } | ActorInput::PaymentFailed { .. } @@ -486,7 +403,6 @@ impl { sessions: Mutex>, - recovery_handles: Mutex>, datastore: Arc, executor: Arc, config: SessionConfig, @@ -49,7 +48,6 @@ impl pub fn new(datastore: Arc, executor: Arc, config: SessionConfig, event_sink: Arc) -> Self { Self { sessions: Mutex::new(HashMap::new()), - recovery_handles: Mutex::new(Vec::new()), datastore, executor, config, @@ -61,8 +59,13 @@ impl let entries = self.datastore.list_active_sessions().await?; for (scid, entry) in entries { + let payment_hash = entry.payment_hash.as_deref().and_then(|s| s.parse::().ok()); if let Some(handle) = self.recover_session(scid, entry, &recovery).await? { - self.recovery_handles.lock().await.push(handle); + if let Some(hash) = payment_hash { + self.sessions.lock().await.insert(hash, handle); + } else { + warn!("recovered session for scid={scid} has no payment_hash, dropping handle"); + } } } @@ -122,12 +125,35 @@ impl .await?; Ok(None) } - ForwardActivity::Offered | ForwardActivity::Settled => { - let forwards_updated_index = entry.forwards_updated_index; + ForwardActivity::Offered => { + let (session, initial_actions) = Session::recover( + channel_id.clone(), + funding_psbt.clone(), + None, + entry.opening_fee_params.clone(), + ); + + let handle = SessionActor::spawn_recovered_session_actor( + session, + entry, + initial_actions, + self.executor.clone(), + scid, + self.datastore.clone(), + self.event_sink.clone(), + ); + + Ok(Some(handle)) + } + ForwardActivity::Settled => { + // Forwards already settled — recover into Broadcasting state + // so the actor self-drives via BroadcastFundingTx without + // needing a forward_event notification from CLN. + let preimage = entry.preimage.clone().unwrap_or_default(); let (session, initial_actions) = Session::recover( channel_id.clone(), funding_psbt.clone(), - entry.preimage.clone(), + Some(preimage), entry.opening_fee_params.clone(), ); @@ -135,12 +161,9 @@ impl session, entry, initial_actions, - channel_id, self.executor.clone(), scid, self.datastore.clone(), - recovery.clone(), - forwards_updated_index, self.event_sink.clone(), ); @@ -694,13 +717,6 @@ mod tests { ) -> anyhow::Result<()> { Ok(()) } - async fn wait_for_forward_resolution( - &self, - _channel_id: &str, - from_index: u64, - ) -> anyhow::Result<(ForwardActivity, u64)> { - Ok((self.forward_activity.clone(), from_index + 1)) - } } #[tokio::test] @@ -821,4 +837,154 @@ mod tests { mgr.recover(recovery).await.unwrap(); assert_eq!(mgr.session_count().await, 0); } + + #[tokio::test] + async fn recover_funded_offered_registers_in_sessions() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + + // Recovered session must be reachable via on_payment_settled. + assert_eq!(mgr.session_count().await, 1); + let result = mgr.on_payment_settled(test_payment_hash(1), None, None).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_offered_reachable_by_on_payment_failed() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + + let result = mgr.on_payment_failed(test_payment_hash(1), None).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_no_payment_hash_not_registered() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = None; // No payment_hash + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 0); + } + + #[tokio::test] + async fn recover_funded_settled_registers_in_sessions() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Settled, + }); + + mgr.recover(recovery).await.unwrap(); + + // Settled sessions should still be registered (actor will receive + // BroadcastFundingTx as initial action and self-drive to completion). + assert_eq!(mgr.session_count().await, 1); + } + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn recovered_actor_settles_via_inbox() { + let mut ds = MockDatastore::new(); + ds.entries.clear(); + let mut entry = test_datastore_entry(); + entry.channel_id = Some("channel-1".to_string()); + entry.funding_psbt = Some("psbt-1".to_string()); + entry.payment_hash = Some(test_payment_hash(1).to_string()); + ds.entries.insert(test_scid().to_string(), entry.clone()); + + let mgr = Arc::new(SessionManager::new( + Arc::new(ds), + Arc::new(MockExecutor { fund_succeeds: true }), + SessionConfig::default(), + Arc::new(NoopEventSink), + )); + + let recovery = Arc::new(MockRecoveryProvider { + channel_exists: true, + forward_activity: ForwardActivity::Offered, + }); + + mgr.recover(recovery).await.unwrap(); + assert_eq!(mgr.session_count().await, 1); + + // Simulate forward_event delivering settlement. + let result = mgr.on_payment_settled(test_payment_hash(1), Some("preimage123".to_string()), Some(1)).await; + assert!(result.is_ok()); + assert_eq!(mgr.session_count().await, 0); + + // Give the actor time to finalize. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } } diff --git a/plugins/lsps-plugin/src/core/lsps2/provider.rs b/plugins/lsps-plugin/src/core/lsps2/provider.rs index 04f9dfaa3f0a..45be1bfef1ba 100644 --- a/plugins/lsps-plugin/src/core/lsps2/provider.rs +++ b/plugins/lsps-plugin/src/core/lsps2/provider.rs @@ -63,15 +63,6 @@ pub trait RecoveryProvider: Send + Sync { /// Close a channel and unreserve its inputs. async fn close_and_unreserve(&self, channel_id: &str, funding_psbt: &str) -> Result<()>; - - /// Monitor forward status changes using the wait subsystem. - /// Returns when a forward on the given channel settles or fails. - /// `from_index` is the last processed updated_index. - async fn wait_for_forward_resolution( - &self, - channel_id: &str, - from_index: u64, - ) -> Result<(ForwardActivity, u64)>; } #[async_trait]