Skip to content

Commit

Permalink
grammargen: avoid the same alternative in succession
Browse files Browse the repository at this point in the history
  • Loading branch information
RavuAlHemio committed Sep 24, 2021
1 parent d142ef7 commit 3c0709f
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 63 deletions.
3 changes: 1 addition & 2 deletions rocketbot_plugin_grammargen/examples/grammargenerate.rs
Expand Up @@ -4,7 +4,6 @@ use std::ffi::OsString;
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
Expand Down Expand Up @@ -87,7 +86,7 @@ async fn main() {
let mut state = GeneratorState::new_topmost(
rulebook,
conditions,
Arc::new(Mutex::new(rng)),
rng,
);

if verify {
Expand Down
137 changes: 87 additions & 50 deletions rocketbot_plugin_grammargen/src/grammar.rs
Expand Up @@ -18,50 +18,70 @@ static FIRST_LETTER_RE: Lazy<Regex> = Lazy::new(|| Regex::new(
const MAX_STACK_DEPTH: usize = 128;


#[derive(Debug)]
pub struct LockedGeneratorState {
pub rng: StdRng,
pub memories: HashMap<usize, Result<String, SoundnessError>>,
pub regex_cache: HashMap<String, Regex>,
pub sound_productions: HashSet<usize>,
pub previous_alternative: HashMap<usize, usize>,
}
impl LockedGeneratorState {
pub fn new(
rng: StdRng,
memories: HashMap<usize, Result<String, SoundnessError>>,
regex_cache: HashMap<String, Regex>,
sound_productions: HashSet<usize>,
previous_alternative: HashMap<usize, usize>,
) -> Self {
Self {
rng,
memories,
regex_cache,
sound_productions,
previous_alternative,
}
}
}

#[derive(Debug)]
pub struct GeneratorState {
pub rulebook: Rulebook,
pub conditions: HashSet<String>,
pub rng: Arc<Mutex<StdRng>>,
pub memories: Arc<Mutex<HashMap<usize, Result<String, SoundnessError>>>>,
pub prod_stack: Vec<usize>,
pub regex_cache: Arc<Mutex<HashMap<String, Regex>>>,
pub sound_productions: Arc<Mutex<HashSet<usize>>>,
pub locked: Arc<Mutex<LockedGeneratorState>>,
}
impl GeneratorState {
pub fn new(
rulebook: Rulebook,
conditions: HashSet<String>,
rng: Arc<Mutex<StdRng>>,
memories: Arc<Mutex<HashMap<usize, Result<String, SoundnessError>>>>,
prod_stack: Vec<usize>,
regex_cache: Arc<Mutex<HashMap<String, Regex>>>,
sound_productions: Arc<Mutex<HashSet<usize>>>,
locked: Arc<Mutex<LockedGeneratorState>>,
) -> GeneratorState {
GeneratorState {
rulebook,
conditions,
rng,
memories,
prod_stack,
regex_cache,
sound_productions,
locked,
}
}

pub fn new_topmost(
rulebook: Rulebook,
conditions: HashSet<String>,
rng: Arc<Mutex<StdRng>>,
rng: StdRng,
) -> Self {
Self::new(
rulebook,
conditions,
rng,
Arc::new(Mutex::new(HashMap::new())),
Vec::new(),
Arc::new(Mutex::new(HashMap::new())),
Arc::new(Mutex::new(HashSet::new())),
Arc::new(Mutex::new(LockedGeneratorState::new(
rng,
HashMap::new(),
HashMap::new(),
HashSet::new(),
HashMap::new(),
))),
)
}

Expand All @@ -82,9 +102,9 @@ impl GeneratorState {
}

pub fn get_or_compile_regex(&self, regex_str: &str) -> Result<Regex, SoundnessError> {
let mut regex_guard = self.regex_cache
.lock().expect("locking regex cache failed");
match regex_guard.entry(regex_str.to_owned()) {
let mut lock_guard = self.locked
.lock().expect("locking failed");
match lock_guard.regex_cache.entry(regex_str.to_owned()) {
HashMapEntry::Occupied(oe) => Ok(oe.get().clone()),
HashMapEntry::Vacant(ve) => {
let regex = match Regex::new(regex_str) {
Expand All @@ -103,19 +123,13 @@ impl Clone for GeneratorState {
fn clone(&self) -> Self {
let rulebook = self.rulebook.clone();
let conditions = self.conditions.clone();
let rng = Arc::clone(&self.rng);
let memories = Arc::clone(&self.memories);
let prod_stack = self.prod_stack.clone();
let regex_cache = Arc::clone(&self.regex_cache);
let sound_productions = Arc::clone(&self.sound_productions);
let locked = Arc::clone(&self.locked);
GeneratorState::new(
rulebook,
conditions,
rng,
memories,
prod_stack,
regex_cache,
sound_productions,
locked,
)
}
}
Expand Down Expand Up @@ -356,39 +370,61 @@ impl Production {
Ok(ret)
},
ProductionKind::Choice { options } => {
let my_alternatives: Vec<&Alternative> = options
let last_time = {
state.locked
.lock().expect("lock failed")
.previous_alternative.get(&self.prod_id)
.map(|i| *i)
};

let mut my_alternatives: Vec<(usize, &Alternative)> = options
.iter()
.filter(|alt| alt.conditions.iter().all(|cond|
.enumerate()
.filter(|(_i, alt)| alt.conditions.iter().all(|cond|
state.conditions.contains(&cond.identifier) != cond.negated
))
.collect();
if my_alternatives.len() == 1 {
// fast-path
return my_alternatives[0].generate(state);
return my_alternatives[0].1.generate(state);
}

// if there is more than one alternative, remove the one we generated previously
if let Some(lt) = last_time {
my_alternatives.retain(|(i, _alt)| *i != lt);
}

let total_weight: BigUint = my_alternatives
.iter()
.map(|alt| &alt.weight)
.map(|(_i, alt)| &alt.weight)
.sum();

if total_weight == Zero::zero() {
// this branch has been "sawed off"
return Err(SoundnessError::NoAlternatives);
}

let mut random_weight = {
let mut rng_guard = state.rng.lock().unwrap();
rng_guard.gen_biguint_range(&Zero::zero(), &total_weight)
state.locked
.lock().expect("failed to lock")
.rng.gen_biguint_range(&Zero::zero(), &total_weight)
};

for alternative in my_alternatives {
for (i, alternative) in &my_alternatives {
if random_weight >= alternative.weight {
random_weight -= &alternative.weight;
continue;
}

return alternative.generate(state)
let generated = alternative.generate(state);

// remember what we did here today
{
state.locked
.lock().expect("lock failed")
.previous_alternative.insert(self.prod_id, *i);
}

return generated;
}

unreachable!();
Expand All @@ -397,8 +433,8 @@ impl Production {
let hundred = BigUint::from(100u8);

let rand_val = {
let mut rng_guard = state.rng.lock().unwrap();
rng_guard.gen_biguint_range(&Zero::zero(), &hundred)
let mut lock_guard = state.locked.lock().unwrap();
lock_guard.rng.gen_biguint_range(&Zero::zero(), &hundred)
};

if &rand_val < weight {
Expand All @@ -417,8 +453,8 @@ impl Production {

loop {
let rand_bool: bool = {
let mut rng_guard = state.rng.lock().unwrap();
rng_guard.gen()
let mut lock_guard = state.locked.lock().unwrap();
lock_guard.rng.gen()
};
if rand_bool {
break;
Expand All @@ -442,8 +478,8 @@ impl Production {

if rule.memoize {
// have we generated this yet?
let memo_guard = state.memories.lock().unwrap();
if let Some(memoized) = memo_guard.get(&self.prod_id) {
let lock_guard = state.locked.lock().unwrap();
if let Some(memoized) = lock_guard.memories.get(&self.prod_id) {
// yup
return memoized.clone();
}
Expand Down Expand Up @@ -476,8 +512,8 @@ impl Production {

if rule.memoize {
// remember me
let mut memo_guard = state.memories.lock().unwrap();
memo_guard.insert(self.prod_id, generated.clone());
let mut lock_guard = state.locked.lock().unwrap();
lock_guard.memories.insert(self.prod_id, generated.clone());
}

generated
Expand Down Expand Up @@ -768,9 +804,9 @@ impl TextGenerator for Production {

// are we already verified?
if self.prod_id != 0 {
let sound_prod_guard = state.sound_productions
.lock().expect("failed to lock set of sound productions");
if sound_prod_guard.contains(&self.prod_id) {
let lock_guard = state.locked
.lock().expect("failed to lock");
if lock_guard.sound_productions.contains(&self.prod_id) {
// yes; don't verify us again
state.prod_stack.pop();
return Ok(());
Expand All @@ -781,8 +817,9 @@ impl TextGenerator for Production {

if let Ok(_) = result {
// mark us as verified
state.sound_productions
.lock().expect("failed to lock set of sound productions")
state.locked
.lock().expect("failed to lock")
.sound_productions
.insert(self.prod_id);
}

Expand Down
15 changes: 4 additions & 11 deletions rocketbot_plugin_grammargen/src/lib.rs
Expand Up @@ -6,7 +6,7 @@ use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, Weak};
use std::sync::Weak;

use async_trait::async_trait;
use log::error;
Expand All @@ -27,7 +27,6 @@ pub struct GrammarGenPlugin {
grammars: HashMap<String, Rulebook>,
grammar_to_allowed_channel_names: HashMap<String, Option<HashSet<String>>>,
word_joiner_in_nicknames: bool,
rng: Arc<Mutex<StdRng>>,
}
#[async_trait]
impl RocketBotPlugin for GrammarGenPlugin {
Expand Down Expand Up @@ -108,16 +107,11 @@ impl RocketBotPlugin for GrammarGenPlugin {
my_interface.register_channel_command(&this_grammar_command).await;
}

let rng = Arc::new(Mutex::new(
StdRng::from_entropy(),
));

GrammarGenPlugin {
interface,
grammars,
grammar_to_allowed_channel_names,
word_joiner_in_nicknames,
rng,
}
}

Expand Down Expand Up @@ -178,16 +172,15 @@ impl RocketBotPlugin for GrammarGenPlugin {
let mut my_grammar = grammar.clone();
my_grammar.add_builtins(&channel_nicks, chosen_nick_opt);

let rng = Arc::clone(&self.rng);
let mut rng = StdRng::from_entropy();
let mut conditions = HashSet::new();

// process metacommands
{
let mut rng_guard = rng.lock().unwrap();
for metacommand in &my_grammar.metacommands {
match metacommand {
Metacommand::RandomizeCondition(cond) => {
let activate_condition: bool = rng_guard.gen();
let activate_condition: bool = rng.gen();
if activate_condition {
conditions.insert(cond.clone());
}
Expand All @@ -203,7 +196,7 @@ impl RocketBotPlugin for GrammarGenPlugin {
let mut state = GeneratorState::new_topmost(
my_grammar,
conditions,
Arc::clone(&self.rng),
rng,
);

let phrase = match state.generate() {
Expand Down

0 comments on commit 3c0709f

Please sign in to comment.