Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
mbellotti committed Jan 28, 2023
1 parent c2ee9ee commit be38b30
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 51 deletions.
12 changes: 12 additions & 0 deletions smt/asserts-new.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package smt

func (g *Generator) parseAssert(assert ast.Node) ([]*assrt, []*assrt, string) {
switch e := assert.(type) {
case *ast.AssertionStatement:
case *ast.AssumptionStatement:
default:
pos := e.Position()
panic(fmt.Sprintf("not a valid assert or assumption line: %d, col: %d", pos[0], pos[1]))
}
}

240 changes: 192 additions & 48 deletions smt/asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,31 @@ import (
"strings"
)

func (g *Generator) parseAssert(assert ast.Node) ([]*assrt, []*assrt, string) {
switch e := assert.(type) {
case *ast.AssertionStatement:
a1 := g.generateAsserts(e.Constraints.Left, e.Constraints.Operator, e.Constraints, e)
a2 := g.generateAsserts(e.Constraints.Right, e.Constraints.Operator, e.Constraints, e)

if e.Constraints.Operator == "&&" || e.Constraints.Operator == "||" {
return a1, a2, e.Constraints.Operator
} else {
a2 = removeDuplicates(a1, a2)
return append(a1, a2...), nil, ""
}
case *ast.AssumptionStatement:
a1 := g.generateAsserts(e.Constraints.Left, e.Constraints.Operator, e.Constraints, e)
a2 := g.generateAsserts(e.Constraints.Right, e.Constraints.Operator, e.Constraints, e)
if e.Constraints.Operator == "&&" || e.Constraints.Operator == "||" {
return a1, a2, e.Constraints.Operator
} else {
return append(a1, a2...), nil, ""
}
default:
pos := e.Position()
panic(fmt.Sprintf("not a valid assert or assumption line: %d, col: %d", pos[0], pos[1]))
}
}
// func (g *Generator) parseAssert(assert ast.Node) ([]*assrt, []*assrt, string) {
// switch e := assert.(type) {
// case *ast.AssertionStatement:
// a1 := g.generateAsserts(e.Constraints.Left, e.Constraints.Operator, e.Constraints, e)
// a2 := g.generateAsserts(e.Constraints.Right, e.Constraints.Operator, e.Constraints, e)

// if e.Constraints.Operator == "&&" || e.Constraints.Operator == "||" {
// return a1, a2, e.Constraints.Operator
// } else {
// a2 = removeDuplicates(a1, a2)
// return append(a1, a2...), nil, ""
// }
// case *ast.AssumptionStatement:
// a1 := g.generateAsserts(e.Constraints.Left, e.Constraints.Operator, e.Constraints, e)
// a2 := g.generateAsserts(e.Constraints.Right, e.Constraints.Operator, e.Constraints, e)
// if e.Constraints.Operator == "&&" || e.Constraints.Operator == "||" {
// return a1, a2, e.Constraints.Operator
// } else {
// return append(a1, a2...), nil, ""
// }
// default:
// pos := e.Position()
// panic(fmt.Sprintf("not a valid assert or assumption line: %d, col: %d", pos[0], pos[1]))
// }
// }

func (g *Generator) generateAsserts(exp ast.Expression, comp string, constr ast.Expression, stmt ast.Statement) []*assrt {
var ident []string
Expand All @@ -46,6 +46,12 @@ func (g *Generator) generateAsserts(exp ast.Expression, comp string, constr ast.
assrt = append(assrt, g.packageAssert(id, comp, v, stmt))
}
return assrt
case *ast.PrefixExpression:
ident = g.findIdent(v)
for _, id := range ident {
assrt = append(assrt, g.packageAssert(id, comp, v, stmt))
}
return assrt
case *ast.Identifier:
ident = g.findIdent(v)
for _, id := range ident {
Expand Down Expand Up @@ -166,6 +172,17 @@ func (g *Generator) parseInvariant(ex ast.Expression) rule {
constant: true,
}
case *ast.PrefixExpression:
right := g.parseInvariant(e.Right)
i := &invariant{
left: nil,
operator: smtlibOperators(e.Operator),
right: right,
}
if e.Operator == "!" { //Not valid in SMTLib
return &invariant{operator: "not",
right: i}
}
return i
case *ast.Nil:
case *ast.IndexExpression:
return &wrap{value: g.convertIndexExpr(e),
Expand All @@ -180,37 +197,164 @@ func (g *Generator) parseInvariant(ex ast.Expression) rule {
return nil
}

type thenStates struct {
roundClauses []string
values [][][]string
}

func (g *Generator) generateThenRules(inv *invariant) []string {
var rounds [][]int
var base string
switch when := inv.left.(type) {
case *wrap:
base = when.value
rounds = g.lookupVarRounds(when.value, when.state)
}
when := g.whenInfixNode(inv.left)
//then := g.thenInfixNode(inv.right)
fmt.Println(when)
//fmt.Println(then)
return []string{}
}

var rules []string
switch then := inv.right.(type) {
case *wrap:
for _, r := range rounds {
values := g.RoundVars[r[1]][r[2]:]
var or []string
for _, v := range values {
if v[0] == then.value {
vname := strings.Join(v, "_")
or = append(or, fmt.Sprintf("(= %s %s)", vname, "true"))
}
}
wclause := fmt.Sprintf("%s_%d", base, r[0])
tclause := fmt.Sprintf("(or %s)", strings.Join(or, " "))
// check rounds for any matching variable names

rules = append(rules, fmt.Sprintf("(and %s %s)", wclause, tclause))
// Generate all permutations of variables in the assert,
// in the round, in between variable state change

func (g *Generator) whenInfixNode(ru rule) map[string]*thenStates {
ret := make(map[string]*thenStates)
switch r := ru.(type) {
case *invariant:
left := g.whenInfixNode(r.left)
for k, v := range left {
ret[k] = v
}

right := g.whenInfixNode(r.right)
for k, v := range right {
ret[k] = v
}

return ret
case *wrap:
roundClause, values := g.whenNode(r)
ret[r.value] = &thenStates{
roundClauses: roundClause,
values: values,
}
return ret
}
return rules
return ret
}

func (g *Generator) whenNode(when *wrap) ([]string, [][][]string) {
var roundClauses []string
var values [][][]string
base := when.value
rounds := g.lookupVarRounds(when.value, when.state)
for _, r := range rounds {
roundClauses = append(roundClauses, fmt.Sprintf("(= %s_%d %s)", base, r[0], "true"))
values = append(values, g.RoundVars[r[1]][r[2]:])
}
return roundClauses, values
}

// func (g *Generator) thenNode(ru rule) {

// }

// func (g *Generator) generateThenRules(inv *invariant) []string {
// var rounds [][]int
// var roundClauses []string
// var values [][][]string
// var base string
// switch when := inv.left.(type) {
// case *invariant:
// if when.left == nil {
// base = when.right.(*wrap).value
// rounds = g.lookupVarRounds(base, when.right.(*wrap).state)
// for _, r := range rounds {
// roundClauses = append(roundClauses, fmt.Sprintf("(= %s_%d %s)", base, r[0], "false"))
// values = append(values, g.RoundVars[r[1]][r[2]:])
// }
// }
// case *wrap:
// base = when.value
// rounds = g.lookupVarRounds(when.value, when.state)
// for _, r := range rounds {
// roundClauses = append(roundClauses, fmt.Sprintf("(= %s_%d %s)", base, r[0], "true"))
// values = append(values, g.RoundVars[r[1]][r[2]:])
// }
// }

// var rules []string
// switch then := inv.right.(type) {
// case *invariant:
// if then.left == nil { //Prefix
// rules = g.constructThen(then.left.(*wrap), values, roundClauses, "false")
// }
// case *wrap:
// rules = g.constructThen(then, values, roundClauses, "true")
// }
// return rules
// }

// func (g *Generator) whenClauses(when *wrap) ([][]int, []string, [][][]string) {
// var roundClauses []string
// var values [][][]string
// base := when.value
// rounds := g.lookupVarRounds(when.value, when.state)
// for _, r := range rounds {
// roundClauses = append(roundClauses, fmt.Sprintf("(= %s_%d %s)", base, r[0], "true"))
// values = append(values, g.RoundVars[r[1]][r[2]:])
// }
// return rounds, roundClauses, values
// }

// func (g *Generator) thenInfix(ru rule) []string {
// switch when := ru.(type) {
// case *invariant:
// leftRounds, leftRC, leftValues := g.thenInfixNode(when.left)
// rightRounds, rightRC, rightValues := g.thenInfixNode(when.right)

// case *wrap:
// rounds, roundClauses, values := g.whenClauses(when)
// default:
// panic("unsupported rule")
// }
// }

// func (G *Generator) thenClauses(then *wrap, values [][][]string) []string {
// var or []string
// var rules []string
// for _, val := range values {
// for _, v := range val {
// if v[0] == then.value {
// vname := strings.Join(v, "_")
// or = append(or, fmt.Sprintf("(= %s %s)", vname, b))
// }
// }
// tclause := fmt.Sprintf("(or %s)", strings.Join(or, " "))

// rules = append(rules, tclause)

// }
// return rules
// }

// func (g *Generator) constructThen(then *wrap, values [][][]string, roundClauses []string, b string) []string {
// var or []string
// var rules []string
// for idx, val := range values {
// for _, v := range val {
// if v[0] == then.value {
// vname := strings.Join(v, "_")
// or = append(or, fmt.Sprintf("(= %s %s)", vname, b))
// }
// }
// wclause := roundClauses[idx]
// tclause := fmt.Sprintf("(or %s)", strings.Join(or, " "))

// rules = append(rules, fmt.Sprintf("(and %s %s)", wclause, tclause))

// }
// return rules
// }

func (g *Generator) packageAssert(ident string, comp string, expr ast.Expression, stmt ast.Statement) *assrt {
var temporalFilter string
var temporalN int
Expand Down
3 changes: 3 additions & 0 deletions smt/rules_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ type invariant struct {

func (i *invariant) ruleNode() {}
func (i *invariant) String() string {
if i.left == nil { //Prefixes like !a
return fmt.Sprint(i.operator, i.right.String())
}
return fmt.Sprint(i.left.String(), i.operator, i.right.String())
}
func (i *invariant) Tag(k1 string, k2 string) {
Expand Down
2 changes: 1 addition & 1 deletion smt/smt_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func prepTestSys(filepath string, test string, imports bool) (string, error) {
generator := NewGenerator()
generator.LoadMeta(compiler.Uncertains, compiler.Unknowns, compiler.Asserts, compiler.Assumes)
generator.Run(compiler.GetIR())
//fmt.Println(generator.SMT())
fmt.Println(generator.SMT())
return generator.SMT(), nil
}

Expand Down

0 comments on commit be38b30

Please sign in to comment.