Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions generic_a_star/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ deterministic_default_hasher = "0.14.2"
num-traits.workspace = true
serde = { workspace = true, features = ["derive"], optional = true }
extend_map = "0.14.4"
compare = "0.1.0"
111 changes: 111 additions & 0 deletions generic_a_star/src/comparator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use std::cmp::Ordering;

use compare::Compare;

use crate::AStarNode;

/// Comparator that orders [`AStarNode`]s for the A* open list.
///
/// Nodes are compared by their reversed [`Ord`] ordering first; ties are
/// broken in favour of the larger
/// [`secondary_maximisable_score`](AStarNode::secondary_maximisable_score).
#[derive(Debug, Default)]
pub struct AStarNodeComparator;

impl<T: AStarNode> Compare<T> for AStarNodeComparator {
    /// Compares `l` with `r` for use in a max-heap.
    ///
    /// The node's own [`Ord`] ordering is reversed, so that a node that is
    /// smaller under `Ord` compares as greater. If both nodes are equal under
    /// `Ord`, the node with the larger secondary maximisable score compares
    /// as greater.
    fn compare(&self, l: &T, r: &T) -> Ordering {
        // Primary criterion: reversed `Ord` of the nodes.
        let primary = l.cmp(r).reverse();
        if primary != Ordering::Equal {
            return primary;
        }

        // Tie-break: the larger secondary score wins. The scores are only
        // queried when the primary criterion does not decide.
        let left_score = l.secondary_maximisable_score();
        let right_score = r.secondary_maximisable_score();
        left_score.cmp(&right_score)
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use compare::Compare;

    use crate::{AStarNode, cost::U64Cost};

    use super::AStarNodeComparator;

    /// Minimal node used to exercise [`AStarNodeComparator`].
    ///
    /// Only the cost, the lower bound and the secondary score are meaningful;
    /// the identifier and predecessor accessors are never called by the
    /// comparator and are therefore left unimplemented.
    #[derive(Debug, PartialEq, Eq)]
    struct Node {
        cost: U64Cost,
        lower_bound: U64Cost,
        secondary_maximisable_score: usize,
    }

    impl Node {
        /// Creates a node from raw cost, lower bound and secondary score.
        fn new(cost: u64, lower_bound: u64, secondary_maximisable_score: usize) -> Self {
            Self {
                cost: cost.into(),
                lower_bound: lower_bound.into(),
                secondary_maximisable_score,
            }
        }
    }

    impl AStarNode for Node {
        type Identifier = ();

        type EdgeType = ();

        type Cost = U64Cost;

        fn identifier(&self) -> &Self::Identifier {
            unimplemented!()
        }

        fn cost(&self) -> Self::Cost {
            self.cost
        }

        fn a_star_lower_bound(&self) -> Self::Cost {
            self.lower_bound
        }

        fn secondary_maximisable_score(&self) -> usize {
            self.secondary_maximisable_score
        }

        fn predecessor(&self) -> Option<&Self::Identifier> {
            unimplemented!()
        }

        fn predecessor_edge_type(&self) -> Option<Self::EdgeType> {
            unimplemented!()
        }
    }

    impl Display for Node {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(
                f,
                "{} + {}; {}",
                self.cost, self.lower_bound, self.secondary_maximisable_score
            )
        }
    }

    /// Nodes are ordered by the sum of cost and lower bound only; the
    /// secondary score is deliberately ignored here so that the comparator's
    /// tie-breaking can be observed.
    impl Ord for Node {
        fn cmp(&self, other: &Self) -> std::cmp::Ordering {
            (self.cost + self.lower_bound).cmp(&(other.cost + other.lower_bound))
        }
    }

    impl PartialOrd for Node {
        fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
            Some(self.cmp(other))
        }
    }

    #[test]
    fn compare() {
        // Heap is a max heap, hence smaller nodes need to be bigger.
        assert!(AStarNodeComparator.compares_eq(&Node::new(4, 5, 6), &Node::new(4, 5, 6)));

        // Greater: smaller lower bound, smaller cost, larger secondary score.
        assert!(AStarNodeComparator.compares_gt(&Node::new(4, 4, 6), &Node::new(4, 5, 6)));
        assert!(AStarNodeComparator.compares_gt(&Node::new(3, 5, 6), &Node::new(4, 5, 6)));
        assert!(AStarNodeComparator.compares_gt(&Node::new(4, 5, 7), &Node::new(4, 5, 6)));

        // Less: smaller secondary score, larger lower bound, larger cost.
        assert!(AStarNodeComparator.compares_lt(&Node::new(4, 5, 5), &Node::new(4, 5, 6)));
        assert!(AStarNodeComparator.compares_lt(&Node::new(4, 6, 6), &Node::new(4, 5, 6)));
        assert!(AStarNodeComparator.compares_lt(&Node::new(5, 5, 6), &Node::new(4, 5, 6)));

        // The secondary score only breaks ties; it never overrides the primary ordering.
        assert!(AStarNodeComparator.compares_gt(&Node::new(4, 4, 5), &Node::new(4, 5, 6)));
        assert!(AStarNodeComparator.compares_lt(&Node::new(4, 6, 7), &Node::new(4, 5, 6)));
    }
}
86 changes: 68 additions & 18 deletions generic_a_star/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
#![forbid(clippy::mod_module_files)]

use std::{
cmp::Ordering,
collections::HashMap,
fmt::{Debug, Display},
hash::Hash,
};

use binary_heap_plus::{BinaryHeap, MinComparator};
use binary_heap_plus::BinaryHeap;
use comparator::AStarNodeComparator;
use compare::Compare;
use cost::AStarCost;
use deterministic_default_hasher::DeterministicDefaultHasher;
use extend_map::ExtendFilter;
use num_traits::Bounded;
use num_traits::{Bounded, Zero};
use reset::Reset;

mod comparator;
pub mod cost;
pub mod reset;

Expand Down Expand Up @@ -43,6 +47,11 @@ pub trait AStarNode: Sized + Ord + Debug + Display {
/// Returns the A* lower bound of this node.
fn a_star_lower_bound(&self) -> Self::Cost;

/// Returns a tie-breaking score for nodes that are otherwise equally good.
///
/// This score should be maximised, which is done via complete search.
/// Higher is better: among nodes that compare equal under `Ord`, the open
/// list prefers the node with the larger score, and among target nodes of
/// equal cost the search keeps the one with the larger score.
fn secondary_maximisable_score(&self) -> usize;

/// Returns the identifier of the predecessor of this node.
fn predecessor(&self) -> Option<&Self::Identifier>;

Expand Down Expand Up @@ -116,14 +125,14 @@ pub struct AStar<Context: AStarContext> {
Context::Node,
DeterministicDefaultHasher,
>,
open_list: BinaryHeap<Context::Node, MinComparator>,
open_list: BinaryHeap<Context::Node, AStarNodeComparator>,
performance_counters: AStarPerformanceCounters,
}

#[derive(Debug)]
pub struct AStarBuffers<NodeIdentifier, Node> {
closed_list: HashMap<NodeIdentifier, Node, DeterministicDefaultHasher>,
open_list: BinaryHeap<Node, MinComparator>,
open_list: BinaryHeap<Node, AStarNodeComparator>,
}

#[derive(Debug, Clone, Ord, PartialOrd, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -166,7 +175,7 @@ impl<Context: AStarContext> AStar<Context> {
state: AStarState::Empty,
context,
closed_list: Default::default(),
open_list: BinaryHeap::new_min(),
open_list: BinaryHeap::from_vec(Vec::new()),
performance_counters: Default::default(),
}
}
Expand Down Expand Up @@ -275,8 +284,11 @@ impl<Context: AStarContext> AStar<Context> {
self.state = AStarState::Searching;

let mut last_node = None;
let mut target_identifier = None;
let mut target_cost = <Context::Node as AStarNode>::Cost::max_value();
let mut target_secondary_maximisable_score = 0;

let target_identifier = loop {
loop {
let Some(node) = self.open_list.pop() else {
if last_node.is_none() {
unreachable!("Open list was empty.");
Expand All @@ -294,6 +306,7 @@ impl<Context: AStarContext> AStar<Context> {
}
};

// Check cost limit.
// Nodes are ordered by cost plus lower bound.
if node.cost() + node.a_star_lower_bound() > cost_limit {
self.state = AStarState::Terminated {
Expand All @@ -302,6 +315,7 @@ impl<Context: AStarContext> AStar<Context> {
return AStarResult::ExceededCostLimit { cost_limit };
}

// Check memory limit.
if self.closed_list.len() + self.open_list.len() > node_count_limit {
self.state = AStarState::Terminated {
result: AStarResult::ExceededMemoryLimit {
Expand All @@ -313,11 +327,19 @@ impl<Context: AStarContext> AStar<Context> {
};
}

// If label-correcting, abort when the first node more expensive than the cheapest target is visited.
if node.cost() + node.a_star_lower_bound() > target_cost {
debug_assert!(!self.context.is_label_setting());
break;
}

last_node = Some(node.identifier().clone());

if let Some(previous_visit) = self.closed_list.get(node.identifier()) {
self.performance_counters.suboptimal_opened_nodes += 1;

if self.context.is_label_setting() {
// In label-setting mode, if we have already visited the node, we now must be visiting it with a higher cost.
// In label-setting mode, if we have already visited the node, we now must be visiting it with a higher or equal cost.
debug_assert!(
previous_visit.cost() + previous_visit.a_star_lower_bound()
<= node.cost() + node.a_star_lower_bound(),
Expand All @@ -343,10 +365,13 @@ impl<Context: AStarContext> AStar<Context> {
out
}
);
}

self.performance_counters.suboptimal_opened_nodes += 1;
continue;
continue;
} else if AStarNodeComparator.compare(&node, previous_visit) != Ordering::Greater {
// If we are label-correcting, we may still find a better node later on.
// Skip if equal or worse.
continue;
}
}

let open_nodes_without_new_successors = self.open_list.len();
Expand All @@ -361,20 +386,41 @@ impl<Context: AStarContext> AStar<Context> {
self.performance_counters.opened_nodes +=
self.open_list.len() - open_nodes_without_new_successors;

if is_target(&self.context, &node) {
let identifier = node.identifier().clone();
let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
self.performance_counters.closed_nodes += 1;
debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
break identifier;
let is_target = is_target(&self.context, &node);
debug_assert!(!is_target || node.a_star_lower_bound().is_zero());

if is_target
&& (node.cost() < target_cost
|| (node.cost() == target_cost
&& node.secondary_maximisable_score() > target_secondary_maximisable_score))
{
target_identifier = Some(node.identifier().clone());
target_cost = node.cost();
target_secondary_maximisable_score = node.secondary_maximisable_score();

if self.context.is_label_setting() {
let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
self.performance_counters.closed_nodes += 1;
debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
break;
}
}

let previous_visit = self.closed_list.insert(node.identifier().clone(), node);
self.performance_counters.closed_nodes += 1;
debug_assert!(previous_visit.is_none() || !self.context.is_label_setting());
}

let Some(target_identifier) = target_identifier else {
debug_assert!(!self.context.is_label_setting());
self.state = AStarState::Terminated {
result: AStarResult::NoTarget,
};
return AStarResult::NoTarget;
};

let cost = self.closed_list.get(&target_identifier).unwrap().cost();
debug_assert_eq!(cost, target_cost);
self.state = AStarState::Terminated {
result: AStarResult::FoundTarget {
identifier: target_identifier.clone(),
Expand Down Expand Up @@ -549,11 +595,11 @@ impl<Context: AStarContext> Iterator for BacktrackingIteratorWithCost<'_, Contex
}
}

impl<NodeIdentifier, Node: Ord> Default for AStarBuffers<NodeIdentifier, Node> {
impl<NodeIdentifier, Node: AStarNode> Default for AStarBuffers<NodeIdentifier, Node> {
fn default() -> Self {
Self {
closed_list: Default::default(),
open_list: BinaryHeap::new_min(),
open_list: BinaryHeap::from_vec(Vec::new()),
}
}
}
Expand Down Expand Up @@ -599,6 +645,10 @@ impl<T: AStarNode> AStarNode for Box<T> {
<T as AStarNode>::a_star_lower_bound(self)
}

/// Forwards to the boxed node's `secondary_maximisable_score`.
fn secondary_maximisable_score(&self) -> usize {
<T as AStarNode>::secondary_maximisable_score(self)
}

fn predecessor(&self) -> Option<&Self::Identifier> {
<T as AStarNode>::predecessor(self)
}
Expand Down
3 changes: 3 additions & 0 deletions lib_tsalign/src/a_star_aligner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ where
<Context::Node as AStarNode>::EdgeType: IAlignmentType,
{
info!("Aligning on subsequence {}", context.range());
debug!("Is label setting: {}", context.is_label_setting());

let start_time = Instant::now();

Expand Down Expand Up @@ -159,6 +160,7 @@ pub fn template_switch_distance_a_star_align<
>,
cost_limit: Option<Strategies::Cost>,
memory_limit: Option<usize>,
force_label_correcting: bool,
template_switch_count_memory: <Strategies::TemplateSwitchCount as TemplateSwitchCountStrategy>::Memory,
) -> AlignmentResult<template_switch_distance::AlignmentType, Strategies::Cost>
where
Expand Down Expand Up @@ -187,6 +189,7 @@ where
memory,
cost_limit,
memory_limit,
force_label_correcting,
));
debug!("CIGAR before extending: {}", result.cigar());

Expand Down
8 changes: 7 additions & 1 deletion lib_tsalign/src/a_star_aligner/alignment_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,13 @@ impl<Cost: AStarCost + From<u64>>
range.query_offset(),
config,
);
assert_eq!(initial_cost, (statistics.cost.round().raw() as u64).into());
let alignment_cost = (statistics.cost.round().raw() as u64).into();
assert_eq!(
initial_cost,
alignment_cost,
"computed cost {initial_cost} != alignment cost {alignment_cost}; {}",
alignment.cigar()
);

// Extend left with equal cost.
while range.reference_offset() > 0 && range.query_offset() > 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ mod tests {
LazyLock::new(|| TemplateSwitchConfig {
left_flank_length: 0,
right_flank_length: 0,
min_length: 3,
template_switch_min_length: 3,
base_cost: BaseCost {
rrf: 10u64.into(),
rqf: 100u64.into(),
Expand Down
Loading
Loading