Skip to content

Commit

Permalink
probably got assert tracking all done now?
Browse files Browse the repository at this point in the history
  • Loading branch information
mbellotti committed Aug 22, 2023
1 parent 42c457c commit 638019a
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 48 deletions.
134 changes: 104 additions & 30 deletions smt/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func (g *Generator) parseAssert(a *ast.AssertionStatement) string {
dset := util.DiffStrSets(left.Bases, right.Bases)
if dset.Len() == 0 && (a.Temporal != "" || a.TemporalFilter != "") {
sg := g.mergeInvariantInfix(left, right, smtlibOperators(a.Constraint.Operator))
ir := g.flattenStates(sg)
return g.applyTemporalLogic(a.Temporal, ir, a.TemporalFilter, on, off)
ir, chain := g.flattenStates(sg)
return g.applyTemporalLogic(a.Temporal, g.NewAssertChain(ir, chain, ""), a.TemporalFilter, on, off)
}

if a.Temporal != "" || a.TemporalFilter != "" {
ir := g.expandAssertStateGraph(g.flattenStates(left), g.flattenStates(right), smtlibOperators(a.Constraint.Operator), a.TemporalFilter, a.TemporalN)
return g.applyTemporalLogic(a.Temporal, ir.Values, a.TemporalFilter, on, off)
ir := g.expandAssertStateGraph(left, right, smtlibOperators(a.Constraint.Operator), a.TemporalFilter, a.TemporalN)
return g.applyTemporalLogic(a.Temporal, ir, a.TemporalFilter, on, off)
}
if a.Assume {
operator := "and"
Expand Down Expand Up @@ -97,7 +97,7 @@ func (g *Generator) parseInvariantNode(exp ast.Expression, stateRange bool) *rul
return wg
case *ast.IntegerLiteral:
s := make(map[int]*rules.AssertChain)
s[0] = &rules.AssertChain{Values: []string{fmt.Sprint(e.Value)}}
s[0] = g.NewAssertChain([]string{fmt.Sprint(e.Value)}, []int{}, "")
sg := rules.NewStateGroup()
sg.AddWrap(&rules.States{
Base: "__int",
Expand All @@ -106,7 +106,7 @@ func (g *Generator) parseInvariantNode(exp ast.Expression, stateRange bool) *rul
return sg
case *ast.FloatLiteral:
s := make(map[int]*rules.AssertChain)
s[0] = &rules.AssertChain{Values: []string{fmt.Sprint(e.Value)}}
s[0] = g.NewAssertChain([]string{fmt.Sprint(e.Value)}, []int{}, "")
sg := rules.NewStateGroup()
sg.AddWrap(&rules.States{
Base: "__float",
Expand All @@ -116,7 +116,7 @@ func (g *Generator) parseInvariantNode(exp ast.Expression, stateRange bool) *rul
return sg
case *ast.Boolean:
s := make(map[int]*rules.AssertChain)
s[0] = &rules.AssertChain{Values: []string{fmt.Sprint(e.Value)}}
s[0] = g.NewAssertChain([]string{fmt.Sprint(e.Value)}, []int{}, "")
sg := rules.NewStateGroup()
sg.AddWrap(&rules.States{
Base: "__bool",
Expand All @@ -126,7 +126,7 @@ func (g *Generator) parseInvariantNode(exp ast.Expression, stateRange bool) *rul
return sg
case *ast.StringLiteral:
s := make(map[int]*rules.AssertChain)
s[0] = &rules.AssertChain{Values: []string{fmt.Sprint(e.Value)}}
s[0] = g.NewAssertChain([]string{fmt.Sprint(e.Value)}, []int{}, "")
sg := rules.NewStateGroup()
sg.AddWrap(&rules.States{
Base: "__string",
Expand Down Expand Up @@ -190,7 +190,7 @@ func (g *Generator) mergeInvariantPrefix(right []*rules.States, operator string)
for i := 0; i <= g.Rounds; i++ {
if st, ok := r.States[i]; ok {
for _, s := range st.Values {
states[i] = &rules.AssertChain{Chain: st.Chain}
states[i] = g.NewAssertChain([]string{}, st.Chain, "")
states[i].Values = append(states[i].Values, fmt.Sprintf("(%s %s)", operator, s))
}
}
Expand Down Expand Up @@ -227,7 +227,7 @@ func (g *Generator) mergeByRound(left *rules.States, right *rules.States, operat
st := make(map[int]*rules.AssertChain)
if left.Constant && right.Constant {
combos := util.PairCombinations(left.States[0].Values, right.States[0].Values)
st[0] = g.packageStateGraph(combos, operator)
st[0] = g.packageStateGraph(combos, operator, left.States[0].Chain, [][]int{})
ret.States = st
return ret
}
Expand All @@ -253,30 +253,44 @@ func (g *Generator) mergeByRound(left *rules.States, right *rules.States, operat
//Pair based on same state
for i := 0; i <= g.Rounds; i++ {
var pairs [][]string
var slast []string
var slast *rules.AssertChain

if _, ok := long[i]; !ok {
long[i] = &rules.AssertChain{}
}

var chains []int
for idx, s := range long[i].Values {
if sstates, ok := short[i]; ok {
slast = sstates.Values
slast = sstates
if len(sstates.Values) > idx {
p := g.mergePairs(s, sstates.Values[idx], leftLead)
pairs = append(pairs, p)
if len(long[i].Chain) > 0 {
chains = append(chains, long[i].Chain[idx])
}
if len(short[i].Chain) > 0 {
chains = append(chains, short[i].Chain[idx])
}
continue
}

p := g.mergePairs(s, sstates.Values[len(sstates.Values)-1], leftLead)
if len(long[i].Chain) > 0 {
chains = append(chains, long[i].Chain[idx])
}
if len(sstates.Chain) > 0 {
chains = append(chains, sstates.Chain[len(sstates.Chain)-1])
}
pairs = append(pairs, p)
continue
}
p := g.mergePairs(s, slast[len(slast)-1], leftLead)
p := g.mergePairs(s, slast.Values[len(slast.Values)-1], leftLead)
chains = append(chains, []int{long[i].Chain[idx], slast.Chain[len(slast.Chain)-1]}...)
pairs = append(pairs, p)

}
st[i] = g.packageStateGraph(pairs, operator)
st[i] = g.packageStateGraph(pairs, operator, chains, [][]int{})
}
ret.States = st
return ret
Expand All @@ -297,9 +311,10 @@ func (g *Generator) mergeByRound(left *rules.States, right *rules.States, operat
}

if left.Terminal && right.Terminal {
var chains []int
combos := g.termCombos(left.Base, right.Base)
for i, c := range combos {
st[i] = g.packageStateGraph(c, operator)
st[i] = g.packageStateGraph(c, operator, chains, [][]int{})
}
ret.States = st
return ret
Expand Down Expand Up @@ -332,14 +347,48 @@ func (g *Generator) mergeByRound(left *rules.States, right *rules.States, operat
}

combos := util.PairCombinations(l.Values, r.Values)
st[i] = g.packageStateGraph(combos, operator)
chains := g.matchChainToCombo(left.GetChains(), right.GetChains(), combos)
st[i] = g.packageStateGraph(combos, operator, []int{}, chains)
llast = l.Values[len(l.Values)-1:]
rlast = r.Values[len(r.Values)-1:]
}
ret.States = st
return ret
}

func (g *Generator) matchChainToCombo(left []int, right []int, combos [][]string) [][]int {
var ret [][]int
merge := append(left, right...)
lookup := make(map[string]int)

for _, c := range combos {
var item []int
if l1, ok := lookup[c[0]]; !ok {
for _, m := range merge {
if g.Log.Asserts[m].String() == c[0] {
item = append(item, m)
lookup[c[0]] = m
}
}
} else {
item = append(item, l1)
}

if l2, ok := lookup[c[1]]; !ok {
for _, m := range merge {
if g.Log.Asserts[m].String() == c[1] {
item = append(item, m)
lookup[c[1]] = m
}
}
} else {
item = append(item, l2)
}
ret = append(ret, item)
}
return ret
}

func (g *Generator) termCombos(lbase string, rbase string) map[int][][]string {
var llast string
var rlast string
Expand Down Expand Up @@ -374,30 +423,34 @@ func (g *Generator) balance(vr *rules.States, con *rules.States, operator string
for i := 0; i <= g.Rounds; i++ {
if v, ok := vr.States[i]; ok {
combos := util.PairCombinations(v.Values, con.States[0].Values)
ret[i] = g.packageStateGraph(combos, operator)
ret[i] = g.packageStateGraph(combos, operator, v.Chain, [][]int{})
}
}
return ret
}

func (g *Generator) flattenStates(sg *rules.StateGroup) []string {
func (g *Generator) flattenStates(sg *rules.StateGroup) ([]string, []int) {
var asserts []string
var chains []int
for _, w := range sg.Wraps {
for i := 0; i <= g.Rounds; i++ {
if s, ok := w.States[i]; ok {
asserts = append(asserts, s.Values...)
chains = append(chains, s.Chain...)
}
}
}
return asserts
return asserts, chains
}

func (g *Generator) joinStates(sg *rules.StateGroup, operator string) string {
asserts := g.flattenStates(sg)
asserts, chains := g.flattenStates(sg)
if len(asserts) == 1 {
return asserts[0]
}
return g.writeAssertlessRule(operator, strings.Join(asserts, " "), "")
ret := g.writeAssertlessRule(operator, strings.Join(asserts, " "), "")
g.Log.AssertChains[ret] = g.NewAssertChain(asserts, chains, operator)
return ret
}

func invalidBase(base string) bool {
Expand Down Expand Up @@ -529,9 +582,12 @@ func (g *Generator) convertIndexExpr(idx *ast.IndexExpression) string {
return strings.Join([]string{idx.Left.String(), idx.Index.String()}, "_")
}

func (g *Generator) expandAssertStateGraph(list1 []string, list2 []string, op string, temporalFilter string, temporalN int) *rules.AssertChain {
func (g *Generator) expandAssertStateGraph(left *rules.StateGroup, right *rules.StateGroup, op string, temporalFilter string, temporalN int) *rules.AssertChain {
var x [][]string
list1, chain1 := g.flattenStates(left)
list2, chain2 := g.flattenStates(right)
c := util.Cartesian(list1, list2)
chains := append(chain1, chain2...)
switch temporalFilter {
// For logic like "no more than X times" "no fewer than X times"
// We need to flip some of the operators and build out more
Expand All @@ -543,36 +599,50 @@ func (g *Generator) expandAssertStateGraph(list1 []string, list2 []string, op st
for _, p := range pairs {
var o []string
var f []string
var chainOn []int
var chainOff []int
for _, on := range p[0] {
// Write the clauses
i := g.Log.NewAssert(on[0], on[1], op)
chainOn = append(chainOn, i)
o = append(o, fmt.Sprintf("(%s %s %s)", op, on[0], on[1]))
}
// For nmt any of the potential on states can be on
var onStr string
if len(o) == 1 {
g.Log.AssertChains[o[0]] = g.NewAssertChain(o, chainOn, op)
onStr = o[0]
} else {
onStr = fmt.Sprintf("(%s %s)", "or", strings.Join(o, " "))
clause := strings.Join(o, " ")
g.Log.AssertChains[clause] = g.NewAssertChain(o, chainOn, "or")
onStr = fmt.Sprintf("(%s %s)", "or", clause)
}

offOp := llvm.OP_NEGATE[op]
for _, off := range p[1] {
if op == "=" {
i := g.Log.NewAssert(off[0], off[1], "!=")
chainOff = append(chainOff, i)
f = append(f, fmt.Sprintf("(%s (%s %s %s))", "not", op, off[0], off[1]))
} else {
i := g.Log.NewAssert(off[0], off[1], offOp)
chainOff = append(chainOff, i)
f = append(f, fmt.Sprintf("(%s %s %s)", offOp, off[0], off[1]))
}
}
// But these states must be off
var offStr string
if len(f) == 1 {
g.Log.AssertChains[f[0]] = g.NewAssertChain(f, chainOff, "")
offStr = f[0]
} else {
offStr = fmt.Sprintf("(%s %s)", "and", strings.Join(f, " "))
clause := strings.Join(f, " ")
g.Log.AssertChains[clause] = g.NewAssertChain(f, chainOff, "and")
offStr = fmt.Sprintf("(%s %s)", "and", clause)
}
x = append(x, []string{onStr, offStr})
}
return g.packageStateGraph(x, "and")
return g.packageStateGraph(x, "and", chains, [][]int{})
case "nft":
// (or (and on on on))
combos := util.Combinations(c, temporalN)
Expand All @@ -591,16 +661,19 @@ func (g *Generator) expandAssertStateGraph(list1 []string, list2 []string, op st
}
x = append(x, []string{onStr})
}
return g.packageStateGraph(x, "or")
return g.packageStateGraph(x, "or", chains, [][]int{})
default:
return g.packageStateGraph(c, op)
return g.packageStateGraph(c, op, chains, [][]int{})
}
}

func (g *Generator) packageStateGraph(x [][]string, op string) *rules.AssertChain {
func (g *Generator) packageStateGraph(x [][]string, op string, subchain []int, subchains [][]int) *rules.AssertChain {
var product []string
var chain []int
for _, a := range x {
for idx, a := range x {
if len(subchains) > 0 {
subchain = subchains[idx]
}
if len(a) == 1 {
product = append(product, a[0])
} else {
Expand All @@ -619,12 +692,13 @@ func (g *Generator) packageStateGraph(x [][]string, op string) *rules.AssertChai

s = fmt.Sprintf("(%s %s %s)", op, a[0], a[1])
}
g.Log.AssertChains[s] = g.NewAssertChain(a, subchain, op)
i := g.Log.NewAssert(a[0], a[1], op)
chain = append(chain, i)
product = append(product, s)
}
}
return &rules.AssertChain{Values: product, Chain: chain}
return g.NewAssertChain(product, chain, "")
}

func smtlibOperators(op string) string {
Expand Down
2 changes: 1 addition & 1 deletion smt/asserts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestSimpleAssert(t *testing.T) {
t.Fatalf("wrong number of asserts in the event log")
}

if g.Log.Asserts[0].Right.Type() != "FLOAT" || g.Log.Asserts[0].Right.GetFloat() != 0 {
if g.Log.Asserts[0].Right.Type() != "INT" || g.Log.Asserts[0].Right.GetInt() != 0 {
t.Fatalf("wrong right value in the first assert")
}

Expand Down

0 comments on commit 638019a

Please sign in to comment.