Skip to content

Commit

Permalink
Arc<> in Sort (circify#200)
Browse files Browse the repository at this point in the history
Decreases memory use in matrix multiply by up to 2.5x
  • Loading branch information
alex-ozdemir committed Jun 24, 2024
1 parent 4aa36e4 commit 4c3a1a5
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/ir/opt/mem/ram/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ fn derivative_gcd(
let pairs = term(
Op::Array(Box::new(ArrayOp {
key: fs.clone(),
val: Sort::Tuple(Box::new([fs.clone(), Sort::Bool])),
val: Sort::new_tuple(vec![fs.clone(), Sort::Bool]),
})),
values
.clone()
Expand Down
5 changes: 3 additions & 2 deletions src/ir/term/dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,12 @@ impl FixedSizeDist {
match op {
Op::Ite => vec![Sort::Bool, sort.clone(), sort.clone()],
o if o.arity() == Some(0) => vec![],
Op::Field(i) => vec![if let Sort::Tuple(mut ss) =
Op::Field(i) => vec![if let Sort::Tuple(ss) =
self.sample_tuple_sort(*i + 1, self.size - 1, rng)
{
let mut ss = (*ss).to_vec();
ss[*i] = sort.clone();
Sort::Tuple(ss)
Sort::new_tuple(ss)
} else {
unreachable!()
}],
Expand Down
2 changes: 1 addition & 1 deletion src/ir/term/ext/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
"UniqDeriGcd pairs: second element must be a bool",
)?;
let arr = Sort::new_array(f.clone(), f.clone(), size);
Ok(Sort::Tuple(Box::new([arr.clone(), arr])))
Ok(Sort::new_tuple(vec![arr.clone(), arr]))
} else {
// non-pair entries value
Err(TypeErrorReason::Custom(
Expand Down
6 changes: 3 additions & 3 deletions src/ir/term/ext/ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ pub fn check(arg_sorts: &[&Sort]) -> Result<Sort, TypeErrorReason> {
let f = pf_or(i_value, "PersistentRamSplit indices")?;
let n_touched = i_size.min(size);
let n_ignored = size - n_touched;
let f_pair = Sort::Tuple(Box::new([f.clone(), f.clone()]));
let f_pair = Sort::new_tuple(vec![f.clone(), f.clone()]);
let ignored_entries_sort = Sort::Tuple(vec![f_pair.clone(); n_ignored].into());
let selected_entries_sort = Sort::Tuple(vec![f_pair.clone(); n_touched].into());
Ok(Sort::Tuple(Box::new([
Ok(Sort::new_tuple(vec![
ignored_entries_sort,
selected_entries_sort.clone(),
selected_entries_sort,
])))
]))
}

/// Evaluate [super::ExtOp::PersistentRamSplit].
Expand Down
16 changes: 11 additions & 5 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
use std::cell::Cell;
use std::collections::BTreeMap;
use std::sync::Arc;

pub mod bv;
pub mod dist;
Expand Down Expand Up @@ -798,11 +799,11 @@ pub enum Sort {
/// Array from one sort to another, of fixed size.
///
/// size presumes an order, and a zero, for the key sort.
Array(Box<ArraySort>),
Array(Arc<ArraySort>),
/// Map from one sort to another.
Map(Box<MapSort>),
Map(Arc<MapSort>),
/// A tuple
Tuple(Box<[Sort]>),
Tuple(Arc<[Sort]>),
}

#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
Expand Down Expand Up @@ -872,6 +873,11 @@ impl Sort {
}
}

/// Create a new tuple sort
pub fn new_tuple(sorts: Vec<Sort>) -> Self {
Self::Tuple(Arc::from(sorts.into_boxed_slice()))
}

#[track_caller]
/// Unwrap the constituent sorts of this array, panicking otherwise.
pub fn as_array(&self) -> (&Sort, &Sort, usize) {
Expand All @@ -884,7 +890,7 @@ impl Sort {

/// Create a new array sort
pub fn new_array(key: Sort, val: Sort, size: usize) -> Self {
Self::Array(Box::new(ArraySort { key, val, size }))
Self::Array(Arc::new(ArraySort { key, val, size }))
}

/// Is this an array?
Expand All @@ -909,7 +915,7 @@ impl Sort {

/// Create a new map sort
pub fn new_map(key: Sort, val: Sort) -> Self {
Self::Map(Box::new(MapSort { key, val }))
Self::Map(Arc::new(MapSort { key, val }))
}

/// The nth element of this sort.
Expand Down
2 changes: 1 addition & 1 deletion src/target/smt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ mod test {
fn tuple_is_sat() {
let t = term![Op::Eq; term![Op::Field(0); term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)]], var("a".into(), Sort::BitVector(4))];
assert!(check_sat(&t));
let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], var("a".into(), Sort::Tuple(vec![Sort::BitVector(4), Sort::BitVector(6)].into_boxed_slice()))];
let t = term![Op::Eq; term![Op::Tuple; bv_lit(0,4), bv_lit(5,6)], var("a".into(), Sort::new_tuple(vec![Sort::BitVector(4), Sort::BitVector(6)]))];
assert!(check_sat(&t));
}

Expand Down

0 comments on commit 4c3a1a5

Please sign in to comment.