diff --git a/src/model.rs b/src/model.rs index 55228d2..929dc81 100644 --- a/src/model.rs +++ b/src/model.rs @@ -28,7 +28,7 @@ pub type Turn = usize; pub type Utility = i64; /// TODO -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum SimpleUtility { WIN = 0, LOSE = 1, diff --git a/src/solver/algorithm/strong/cyclic.rs b/src/solver/algorithm/strong/cyclic.rs index 0b2a07f..0cf7771 100644 --- a/src/solver/algorithm/strong/cyclic.rs +++ b/src/solver/algorithm/strong/cyclic.rs @@ -9,6 +9,7 @@ //! - Ishir Garg, 3/12/2024 (ishirgarg@berkeley.edu) use anyhow::{Context, Result}; + use std::collections::{HashMap, VecDeque}; use crate::database::volatile; @@ -20,18 +21,7 @@ use crate::model::{PlayerCount, Remoteness, State, Turn}; use crate::solver::record::sur::RecordBuffer; use crate::solver::RecordType; -/* CONSTANTS */ - -/// The exact number of bits that are used to encode remoteness. -const REMOTENESS_SIZE: usize = 16; - -/// The maximum number of bits that can be used to encode a record. -const BUFFER_SIZE: usize = 128; - -/// The exact number of bits that are used to encode utility for one player. -const UTILITY_SIZE: usize = 2; - -pub fn two_player_zero_sum_dynamic_solver( +pub fn dynamic_solver( game: &G, mode: IOMode, ) -> Result<()> @@ -40,11 +30,11 @@ where { let mut db = volatile_database(game).context("Failed to initialize database.")?; - basic_loopy_solver(game, &mut db)?; + cyclic_solver(game, &mut db)?; Ok(()) } -fn basic_loopy_solver(game: &G, db: &mut D) -> Result<()> +fn cyclic_solver(game: &G, db: &mut D) -> Result<()> where G: DTransition + Bounded + SimpleSum<2> + Extensive<2> + Game, D: KVStore, @@ -91,7 +81,7 @@ where let parents = game.retrograde(child); // If child is a losing position - if matches!(child_utility, SimpleUtility::LOSE) { + if let SimpleUtility::LOSE = child_utility { for parent in parents { if *child_counts.get(&parent).expect("Failed to enqueue parent state in initial enqueueing stage") > 0 { // Add database entry diff --git a/src/solver/algorithm/strong/puzzle.rs b/src/solver/algorithm/strong/puzzle.rs index e414e18..c645043 100644 --- a/src/solver/algorithm/strong/puzzle.rs +++ b/src/solver/algorithm/strong/puzzle.rs @@ -6,6 +6,9 @@ //! - Ishir Garg (ishirgarg@berkeley.edu) use anyhow::{Context, Result}; +use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; + +use std::collections::{HashMap, VecDeque}; use crate::database::volatile; use crate::database::{KVStore, Tabular}; @@ -15,8 +18,6 @@ use crate::model::SimpleUtility; use crate::model::{Remoteness, State}; use crate::solver::error::SolverError::SolverViolation; use crate::solver::record::surcc::{ChildCount, RecordBuffer}; -use bitvec::{order::Msb0, prelude::*, slice::BitSlice, store::BitStore}; -use std::collections::{HashMap, HashSet, VecDeque}; pub fn dynamic_solver(game: &G, mode: IOMode) -> Result<()> where @@ -34,6 +35,12 @@ where Ok(()) } +/// Runs BFS starting from the ending primitive positions of a game, and working +/// its way up the game tree in reverse. Assigns a remoteness and simple +/// utiliity to every winning and losing position. Draws (positions where +/// winning is impossible, but it is possible to play forever without losing) +/// not assigned a remoteness. This implementation uses the SURCC record to +/// store child count along with utility and remoteness. fn reverse_bfs_solver(db: &mut D, game: &G) -> Result<()> where G: DTransition @@ -42,7 +49,6 @@ where + Game, D: KVStore, { - // Get end states and create frontiers let end_states = discover_child_counts(db, game)?; let mut winning_queue: VecDeque = VecDeque::new(); @@ -63,18 +69,34 @@ where })?, } // Add ending state utility and remoteness to database - update_db_record(db, end_state, game.utility(end_state), 0, 0)?; + update_db_record(db, end_state, ClassicPuzzle::utility(game, end_state), 0, 0)?; } - // Perform BFS on winning states + reverse_bfs_winning_states(db, game, &mut winning_queue)?; + reverse_bfs_losing_states(db, game, &mut losing_queue)?; + + Ok(()) +} + +/// Performs BFS on winning states, marking visited states as a win +fn reverse_bfs_winning_states( + db: &mut D, + game: &G, + winning_queue: &mut VecDeque +) -> Result<()> +where + G: DTransition + + Bounded + + ClassicPuzzle + + Game, + D: KVStore, +{ while let Some(state) = winning_queue.pop_front() { let buf = RecordBuffer::from(db.get(state).unwrap())?; - let child_remoteness = - RecordBuffer::from(db.get(state).unwrap())?.get_remoteness(); + let child_remoteness = buf.get_remoteness(); for parent in game.retrograde(state) { - let child_count = - RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); + let child_count = RecordBuffer::from(db.get(parent).unwrap())?.get_child_count(); if child_count > 0 { winning_queue.push_back(parent); update_db_record( @@ -87,10 +109,24 @@ where } } } + + Ok(()) +} - // Perform BFS on losing states, where remoteness is the longest path to a - // losing primitive - // position. +/// Performs BFS on losing states, marking visited states as a loss. Remoteness +/// is the shortest path to a primitive losing position. +fn reverse_bfs_losing_states( + db: &mut D, + game: &G, + losing_queue: &mut VecDeque +)-> Result<()> +where + G: DTransition + + Bounded + + ClassicPuzzle + + Game, + D: KVStore, +{ while let Some(state) = losing_queue.pop_front() { let parents = game.retrograde(state); let child_remoteness = @@ -122,7 +158,7 @@ where } } } - + Ok(()) } @@ -160,14 +196,12 @@ where D: KVStore, { let mut end_states = Vec::new(); - discover_child_counts_helper(db, game, game.start(), &mut end_states)?; + discover_child_counts_from_state(db, game, game.start(), &mut end_states)?; Ok(end_states) } -/// Adds child counts for each position to the database -/// Also returns a vector of all primitive positions -fn discover_child_counts_helper( +fn discover_child_counts_from_state( db: &mut D, game: &G, state: State, @@ -197,16 +231,16 @@ where db.put(state, &buf); // We need to check both prograde and retrograde; consider a game with 3 - // nodes where 0-->2 and 1-->2. Then, starting from node 0 with only - // progrades would discover states 0 and 1; we need to include retrogrades - // to discover state 2. + // nodes where the edges are `0` → `2` and `1` → `2`. Then, starting from + // node 0 with only progrades would discover states 0 and 1; we need to + // include retrogrades to discover state 2. for &child in game .prograde(state) .iter() .chain(game.retrograde(state).iter()) { if db.get(child).is_none() { - discover_child_counts_helper(db, game, child, end_states)?; + discover_child_counts_from_state(db, game, child, end_states)?; } } @@ -227,7 +261,7 @@ where let id = game.id(); let db = volatile::Database::initialize(); - let schema = RecordType::SUR(1) + let schema = RecordType::SURCC(1) .try_into() .context("Failed to create table schema for solver records.")?; db.create_table(&id, schema) @@ -240,7 +274,6 @@ where } */ -// THIS IS ONLY FOR TESTING PURPOSES struct TestDB { memory: HashMap>, } @@ -271,7 +304,7 @@ impl KVStore for TestDB { } fn del(&mut self, key: State) { - unimplemented![]; + unimplemented!(); } } @@ -285,6 +318,9 @@ where #[cfg(test)] mod tests { + use anyhow::Result; + + use crate::game::mock::{Session, SessionBuilder}; use crate::database::{KVStore, Tabular}; use crate::game::mock; use crate::game::{ @@ -296,11 +332,9 @@ mod tests { use crate::model::{State, Turn}; use crate::node; use crate::solver::record::surcc::RecordBuffer; - use anyhow::Result; - use std::collections::{HashMap, VecDeque}; use super::{ - discover_child_counts, reverse_bfs_solver, volatile_database, TestDB, + reverse_bfs_solver, volatile_database, }; struct GameNode { @@ -390,10 +424,10 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); + ); assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), 0 @@ -421,14 +455,14 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); - assert!(matches!( + ); + assert_eq!( RecordBuffer::from(db.get(1).unwrap())?.get_utility(0)?, SimpleUtility::WIN - )); + ); assert_eq!( RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), @@ -480,26 +514,14 @@ mod tests { )); } - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 0 - ); + let expected_remoteness = [1, 2, 1, 1, 0]; + + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ) + } Ok(()) } @@ -540,10 +562,10 @@ mod tests { reverse_bfs_solver(&mut db, &graph); for i in 0..5 { - assert!(matches!( + assert_eq!( RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, SimpleUtility::DRAW - )); + ); } Ok(()) @@ -596,61 +618,33 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - for i in 0..=5 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); - } - assert!(matches!( - RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, + let expected_utilities = [ + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::LOSE, SimpleUtility::LOSE - )); + ]; - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 3 - ); - assert_eq!( - RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(7).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), - 0 - ); + let expected_remoteness = [2, 4, 4, 1, 3, 0, 1, 2, 0]; + + for (i, &utility) in expected_utilities.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + utility + ); + } + + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ); + } Ok(()) } @@ -688,8 +682,8 @@ mod tests { children: vec![8], }, GameNode { - utility: None, - children: vec![6, 8, 13], + utility: Some(SimpleUtility::LOSE), + children: vec![9, 2], }, GameNode { utility: Some(SimpleUtility::LOSE), @@ -701,15 +695,15 @@ mod tests { }, GameNode { utility: Some(SimpleUtility::LOSE), - children: vec![11], + children: vec![7], }, GameNode { - utility: Some(SimpleUtility::LOSE), - children: vec![9, 2], + utility: None, + children: vec![6, 8, 13], }, GameNode { utility: Some(SimpleUtility::LOSE), - children: vec![7], + children: vec![11], }, GameNode { utility: Some(SimpleUtility::LOSE), @@ -722,83 +716,38 @@ mod tests { let mut db = volatile_database(&graph)?; reverse_bfs_solver(&mut db, &graph); - for i in 0..=5 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); + let expected_utilities = [ + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::WIN, + SimpleUtility::LOSE, + SimpleUtility::WIN, + SimpleUtility::WIN, + SimpleUtility::DRAW, + SimpleUtility::DRAW, + SimpleUtility::DRAW, + ]; + + let expected_remoteness = [2, 1, 4, 1, 3, 0, 1, 5, 0, 7, 6]; + + for (i, &utility) in expected_utilities.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_utility(0)?, + utility + ); } - assert!(matches!( - RecordBuffer::from(db.get(6).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - assert!(matches!( - RecordBuffer::from(db.get(7).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert!(matches!( - RecordBuffer::from(db.get(8).unwrap())?.get_utility(0)?, - SimpleUtility::LOSE - )); - for i in 9..=11 { - assert!(matches!( - RecordBuffer::from(db.get(i).unwrap())?.get_utility(0)?, - SimpleUtility::WIN - )); - } - assert!(matches!( - RecordBuffer::from(db.get(12).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert!(matches!( - RecordBuffer::from(db.get(13).unwrap())?.get_utility(0)?, - SimpleUtility::DRAW - )); - assert_eq!( - RecordBuffer::from(db.get(0).unwrap())?.get_remoteness(), - 2 - ); - assert_eq!( - RecordBuffer::from(db.get(1).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(2).unwrap())?.get_remoteness(), - 4 - ); - assert_eq!( - RecordBuffer::from(db.get(3).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(4).unwrap())?.get_remoteness(), - 3 - ); - assert_eq!( - RecordBuffer::from(db.get(5).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(6).unwrap())?.get_remoteness(), - 1 - ); - assert_eq!( - RecordBuffer::from(db.get(8).unwrap())?.get_remoteness(), - 0 - ); - assert_eq!( - RecordBuffer::from(db.get(9).unwrap())?.get_remoteness(), - 7 - ); - assert_eq!( - RecordBuffer::from(db.get(10).unwrap())?.get_remoteness(), - 6 - ); - assert_eq!( - RecordBuffer::from(db.get(11).unwrap())?.get_remoteness(), - 5 - ); + for (i, &remoteness) in expected_remoteness.iter().enumerate() { + assert_eq!( + RecordBuffer::from(db.get(i as u64).unwrap())?.get_remoteness(), + remoteness + ); + } Ok(()) }