Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verify incoming TransmissionResponses #3083

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 13 additions & 19 deletions node/bft/ledger-service/src/ledger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::{fmt_id, spawn_blocking, LedgerService};
use crate::{spawn_blocking, LedgerService};
use snarkvm::{
ledger::{
block::{Block, Transaction},
Expand Down Expand Up @@ -171,8 +171,8 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
}
}

/// Ensures the given transmission ID matches the given transmission.
fn ensure_transmission_id_matches(
/// Checks the given transmission is well-formed and unique.
async fn check_transmission_basic(
&self,
transmission_id: TransmissionID<N>,
transmission: &mut Transmission<N>,
Expand All @@ -182,16 +182,13 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
(TransmissionID::Transaction(expected_transaction_id), Transmission::Transaction(transaction_data)) => {
match transaction_data.clone().deserialize_blocking() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not ideal that we do blocking deserialization in an async context; we should either do it before calling check_transaction_basic (with a deserialized transaction), or use a blocking task here

Ok(transaction) => {
if transaction.id() != expected_transaction_id {
bail!(
"Received mismatching transaction ID - expected {}, found {}",
fmt_id(expected_transaction_id),
fmt_id(transaction.id()),
);
}
let deserialized_transaction = Data::Object(transaction);

// Check that the transaction is valid.
self.check_transaction_basic(expected_transaction_id, deserialized_transaction.clone()).await?;

// Update the transmission with the deserialized transaction.
*transaction_data = Data::Object(transaction);
*transaction_data = deserialized_transaction;
}
Err(err) => {
bail!("Failed to deserialize transaction: {err}");
Expand All @@ -201,16 +198,13 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
(TransmissionID::Solution(expected_commitment), Transmission::Solution(solution_data)) => {
match solution_data.clone().deserialize_blocking() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Ok(solution) => {
if solution.commitment() != expected_commitment {
bail!(
"Received mismatching solution ID - expected {}, found {}",
fmt_id(expected_commitment),
fmt_id(solution.commitment()),
);
}
let deserialized_solution = Data::Object(solution);

// Check that the solution is valid.
self.check_solution_basic(expected_commitment, deserialized_solution.clone()).await?;

// Update the transmission with the deserialized solution.
*solution_data = Data::Object(solution);
*solution_data = deserialized_solution;
}
Err(err) => {
bail!("Failed to deserialize solution: {err}");
Expand Down
6 changes: 3 additions & 3 deletions node/bft/ledger-service/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ impl<N: Network> LedgerService<N> for MockLedgerService<N> {
Ok(false)
}

/// Ensures the given transmission ID matches the given transmission.
fn ensure_transmission_id_matches(
/// Checks the given transmission is well-formed and unique.
async fn check_transmission_basic(
&self,
transmission_id: TransmissionID<N>,
_transmission: &mut Transmission<N>,
) -> Result<()> {
trace!("[MockLedgerService] Ensure transmission ID matches {:?} - Ok", fmt_id(transmission_id));
trace!("[MockLedgerService] Check transmission basic {:?} - Ok", fmt_id(transmission_id));
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions node/bft/ledger-service/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ impl<N: Network> LedgerService<N> for ProverLedgerService<N> {
bail!("Transmission '{transmission_id}' does not exist in prover")
}

/// Ensures the given transmission ID matches the given transmission.
fn ensure_transmission_id_matches(
/// Checks the given transmission is well-formed and unique.
async fn check_transmission_basic(
&self,
_transmission_id: TransmissionID<N>,
_transmission: &mut Transmission<N>,
Expand Down
4 changes: 2 additions & 2 deletions node/bft/ledger-service/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ pub trait LedgerService<N: Network>: Debug + Send + Sync {
/// Returns `true` if the ledger contains the given transmission ID.
fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;

/// Ensures the given transmission ID matches the given transmission.
fn ensure_transmission_id_matches(
/// Checks the given transmission is well-formed and unique.
async fn check_transmission_basic(
&self,
transmission_id: TransmissionID<N>,
transmission: &mut Transmission<N>,
Expand Down
2 changes: 1 addition & 1 deletion node/bft/ledger-service/src/translucent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for TranslucentLedgerS
}

/// Always succeeds.
fn ensure_transmission_id_matches(
async fn check_transmission_basic(
&self,
_transmission_id: TransmissionID<N>,
_transmission: &mut Transmission<N>,
Expand Down
24 changes: 13 additions & 11 deletions node/bft/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl<N: Network> Worker<N> {
self.spawn(async move {
while let Some((peer_ip, transmission_response)) = rx_transmission_response.recv().await {
// Process the transmission response.
self_.finish_transmission_request(peer_ip, transmission_response);
self_.finish_transmission_request(peer_ip, transmission_response).await;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to process the transmission response entries in the order they arrive? if not, we could spawn tasks that performs this action, in order to be able to perform multiple of them concurrently

}
});
}
Expand Down Expand Up @@ -393,19 +393,19 @@ impl<N: Network> Worker<N> {

/// Handles the incoming transmission response.
/// This method ensures the transmission response is well-formed and matches the transmission ID.
fn finish_transmission_request(&self, peer_ip: SocketAddr, response: TransmissionResponse<N>) {
async fn finish_transmission_request(&self, peer_ip: SocketAddr, response: TransmissionResponse<N>) {
let TransmissionResponse { transmission_id, mut transmission } = response;
// Check if the peer IP exists in the pending queue for the given transmission ID.
let exists = self.pending.get(transmission_id).unwrap_or_default().contains(&peer_ip);
// If the peer IP exists, finish the pending request.
if exists {
// Ensure the transmission ID matches the transmission.
match self.ledger.ensure_transmission_id_matches(transmission_id, &mut transmission) {
// Check that the given transmission is well-formed and unique.
match self.ledger.check_transmission_basic(transmission_id, &mut transmission).await {
Ok(()) => {
// Remove the transmission ID from the pending queue.
self.pending.remove(transmission_id, Some(transmission));
}
Err(err) => warn!("Failed to finish transmission response from peer '{peer_ip}': {err}"),
Err(err) => warn!("Malicious peer ('{peer_ip}') sent an invalid transmission: {err}"),
};
}
}
Expand Down Expand Up @@ -487,7 +487,7 @@ mod tests {
fn get_previous_committee_for_round(&self, round: u64) -> Result<Committee<N>>;
fn contains_certificate(&self, certificate_id: &Field<N>) -> Result<bool>;
fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool>;
fn ensure_transmission_id_matches(
async fn check_transmission_basic(
&self,
transmission_id: TransmissionID<N>,
transmission: &mut Transmission<N>,
Expand Down Expand Up @@ -558,7 +558,7 @@ mod tests {
});
let mut mock_ledger = MockLedger::default();
mock_ledger.expect_current_committee().returning(move || Ok(committee.clone()));
mock_ledger.expect_ensure_transmission_id_matches().returning(|_, _| Ok(()));
mock_ledger.expect_check_transmission_basic().returning(|_, _| Ok(()));
let ledger: Arc<dyn LedgerService<CurrentNetwork>> = Arc::new(mock_ledger);
// Initialize the storage.
let storage = Storage::<CurrentNetwork>::new(ledger.clone(), Arc::new(BFTMemoryService::new()), 1);
Expand All @@ -572,10 +572,12 @@ mod tests {
assert!(worker.pending.contains(transmission_id));
let peer_ip = SocketAddr::from(([127, 0, 0, 1], 1234));
// Fake the transmission response.
worker.finish_transmission_request(peer_ip, TransmissionResponse {
transmission_id,
transmission: Transmission::Solution(Data::Buffer(Bytes::from(vec![0; 512]))),
});
worker
.finish_transmission_request(peer_ip, TransmissionResponse {
transmission_id,
transmission: Transmission::Solution(Data::Buffer(Bytes::from(vec![0; 512]))),
})
.await;
// Check the transmission was removed from the pending set.
assert!(!worker.pending.contains(transmission_id));
}
Expand Down