Skip to content

Commit

Permalink
Merge pull request #754 from felixwellen/stable-ringsolver
Browse files Browse the repository at this point in the history
Make the ring solver work in more situations
  • Loading branch information
ecavallo committed May 2, 2022
2 parents cafa8a5 + 1c43684 commit 3ad5a9f
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 71 deletions.
39 changes: 37 additions & 2 deletions Cubical/Algebra/RingSolver/Examples.agda
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,38 @@
module Cubical.Algebra.RingSolver.Examples where

open import Cubical.Foundations.Prelude
open import Cubical.Data.Int.Base hiding (_+_ ; _·_ ; _-_)
open import Cubical.Foundations.Structure

open import Cubical.Data.Int.Base hiding (ℤ; _+_ ; _·_ ; _-_)
open import Cubical.Data.List

open import Cubical.Algebra.CommRing
open import Cubical.Algebra.CommRing.Instances.Int
open import Cubical.Algebra.CommAlgebra
open import Cubical.Algebra.RingSolver.Reflection

private
variable
: Level
ℓ ℓ' : Level

module TestErrors (R : CommRing ℓ) where
open CommRingStr (snd R)

{-
The following should give an type checking error,
making the user aware that the problem is, that 'Type₀'
is not a CommRing.
-}
{-
_ : 0r ≡ 0r
_ = solve Type₀
-}

module TestWithℤ where
open CommRingStr (ℤ .snd)

_ : (a b : fst ℤ) a + b ≡ b + a
_ = solve ℤ

module Test (R : CommRing ℓ) where
open CommRingStr (snd R)
Expand Down Expand Up @@ -66,6 +88,19 @@ module Test (R : CommRing ℓ) where
_ = solve R
-}

module _ (R : CommRing ℓ) (A : CommAlgebra R ℓ') where
open CommAlgebraStr {{...}}
private
instance
_ = (snd A)
{-
The ring solver should also be able to deal with more complicated arguments
and operations with that are not given as the exact names in CommRingStr.
-}
_ : (x y : ⟨ A ⟩) x + y ≡ y + x
_ = solve (CommAlgebra→CommRing A)


module TestInPlaceSolving (R : CommRing ℓ) where
open CommRingStr (snd R)

Expand Down
197 changes: 128 additions & 69 deletions Cubical/Algebra/RingSolver/Reflection.agda
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
{-
This is inspired by/copied from:
https://github.com/agda/agda-stdlib/blob/master/src/Tactic/MonoidSolver.agda
https://github.com/agda/agda-stdlib/blob/master/src/Tactic/RingSolver.agda
Boilerplate code for calling the ring solver is constructed automatically
with agda's reflection features.
Expand Down Expand Up @@ -30,6 +31,7 @@ open import Cubical.Data.Vec using (Vec) renaming ([] to emptyVec; _∷_ to _∷

open import Cubical.Algebra.RingSolver.AlgebraExpression
open import Cubical.Algebra.CommRing
open import Cubical.Algebra.CommRing.Instances.Int using () renaming (ℤ to ℤRing)
open import Cubical.Algebra.RingSolver.RawAlgebra
open import Cubical.Algebra.RingSolver.IntAsRawRing
open import Cubical.Algebra.RingSolver.Solver renaming (solve to ringSolve)
Expand All @@ -41,6 +43,40 @@ private
_==_ = primQNameEquality
{-# INLINE _==_ #-}

record RingNames : Type where
field
is0 : Name Bool
is1 : Name Bool
is· : Name Bool
is+ : Name Bool
is- : Name Bool

getName : Term Maybe Name
getName (con c args) = just c
getName (def f args) = just f
getName _ = nothing

buildMatcher : Name Maybe Name Name Bool
buildMatcher n nothing x = n == x
buildMatcher n (just m) x = n == x or m == x

findRingNames : Term TC RingNames
findRingNames cring =
let cringStr = varg (def (quote snd) (varg cring ∷ [])) ∷ []
in do
0altName normalise (def (quote CommRingStr.0r) cringStr)
1altName normalise (def (quote CommRingStr.1r) cringStr)
·altName normalise (def (quote CommRingStr._·_) cringStr)
+altName normalise (def (quote CommRingStr._+_) cringStr)
-altName normalise (def (quote (CommRingStr.-_)) cringStr)
returnTC record {
is0 = buildMatcher (quote CommRingStr.0r) (getName 0altName) ;
is1 = buildMatcher (quote CommRingStr.1r) (getName 1altName) ;
is· = buildMatcher (quote CommRingStr._·_) (getName ·altName) ;
is+ = buildMatcher (quote CommRingStr._+_) (getName +altName) ;
is- = buildMatcher (quote (CommRingStr.-_)) (getName -altName)
}

record VarInfo : Type ℓ-zero where
field
varName : String
Expand Down Expand Up @@ -83,7 +119,7 @@ private
x ≡⟨ solve ... ⟩ (y ≡⟨ ... ⟩ z ∎)
-}
getRhs : Term Maybe Term
getRhs reasoningToTheRight@(def n xs) =
getRhs (def n xs) =
if n == (quote _∎)
then firstVisibleArg xs
else (if n == (quote _≡⟨_⟩_)
Expand Down Expand Up @@ -136,8 +172,9 @@ module pr (R : CommRing ℓ) {n : ℕ} where
1' : Expr ℤAsRawRing (fst R) n
1' = K 1

module _ (cring : Term) where
module _ (cring : Term) (names : RingNames) where
open pr
open RingNames names

`0` : List (Arg Term) Term
`0` [] = def (quote 0') (varg cring ∷ [])
Expand All @@ -152,26 +189,30 @@ module _ (cring : Term) where
`1` _ = unknown

mutual
private
op2 : Name Term Term Term
op2 op x y = con op (varg (buildExpression x) ∷ varg (buildExpression y) ∷ [])

op1 : Name Term Term
op1 op x = con op (varg (buildExpression x) ∷ [])

`_·_` : List (Arg Term) Term
`_·_` (harg _ ∷ xs) = `_·_` xs
`_·_` (varg _ ∷ varg x ∷ varg y ∷ []) =
con
(quote _·'_) (varg (buildExpression x) ∷ varg (buildExpression y) ∷ [])
`_·_` (varg x ∷ varg y ∷ []) = op2 (quote _·'_) x y
`_·_` (varg _ ∷ varg x ∷ varg y ∷ []) = op2 (quote _·'_) x y
`_·_` _ = unknown

`_+_` : List (Arg Term) Term
`_+_` (harg _ ∷ xs) = `_+_` xs
`_+_` (varg _ ∷ varg x ∷ varg y ∷ []) =
con
(quote _+'_) (varg (buildExpression x) ∷ varg (buildExpression y) ∷ [])
`_+_` (varg x ∷ varg y ∷ []) = op2 (quote _+'_) x y
`_+_` (varg _ ∷ varg x ∷ varg y ∷ []) = op2 (quote _+'_) x y
`_+_` _ = unknown

`-_` : List (Arg Term) Term
`-_` (harg _ ∷ xs) = `-_` xs
`-_` (varg _ ∷ varg x ∷ []) =
con
(quote -'_) (varg (buildExpression x) ∷ [])
`-_` (varg x ∷ []) = op1 (quote -'_) x
`-_` (varg _ ∷ varg x ∷ []) = op1 (quote -'_) x

`-_` _ = unknown

K' : List (Arg Term) Term
Expand All @@ -184,20 +225,20 @@ module _ (cring : Term) where
buildExpression : Term Term
buildExpression (var index _) = con (quote ∣) (varg (finiteNumberAsTerm index) ∷ [])
buildExpression t@(def n xs) =
switch (n ==_) cases
case (quote CommRingStr.0r) ⇒ `0` xs break
case (quote CommRingStr.1r) ⇒ `1` xs break
case (quote CommRingStr._·_) ⇒ `_·_` xs break
case (quote CommRingStr._+_) ⇒ `_+_` xs break
case (quote (CommRingStr.-_)) ⇒ `-_` xs break
switch (λ f f n) cases
case is0 ⇒ `0` xs break
case is1 ⇒ `1` xs break
case is· ⇒ `_·_` xs break
case is+ ⇒ `_+_` xs break
case is- ⇒ `-_` xs break
default⇒ (K' xs)
buildExpression t@(con n xs) =
switch (n ==_) cases
case (quote CommRingStr.0r) ⇒ `0` xs break
case (quote CommRingStr.1r) ⇒ `1` xs break
case (quote CommRingStr._·_) ⇒ `_·_` xs break
case (quote CommRingStr._+_) ⇒ `_+_` xs break
case (quote (CommRingStr.-_)) ⇒ `-_` xs break
switch (λ f f n) cases
case is0 ⇒ `0` xs break
case is1 ⇒ `1` xs break
case is· ⇒ `_·_` xs break
case is+ ⇒ `_+_` xs break
case is- ⇒ `-_` xs break
default⇒ (K' xs)
buildExpression t = unknown

Expand All @@ -206,9 +247,34 @@ module _ (cring : Term) where
toAlgebraExpression (just (lhs , rhs)) = just (buildExpression lhs , buildExpression rhs)

private
adjustDeBruijnIndex : (n : ℕ) Term Term
adjustDeBruijnIndex n (var k args) = var (k +ℕ n) args
adjustDeBruijnIndex n _ = unknown

holeMalformedError : {A : Type ℓ} Term TC A
holeMalformedError hole′ = typeError
(strErr "Something went wrong when getting the variable names in "
∷ termErr hole′ ∷ [])

astExtractionError : {A : Type ℓ} Term TC A
astExtractionError equation = typeError
(strErr "Error while trying to build ASTs for the equation "
termErr equation ∷ [])

variableExtractionError : {A : Type ℓ} Term TC A
variableExtractionError varsToSolve = typeError
(strErr "Error reading variables to solve "
termErr varsToSolve ∷
[])


mutual
{- this covers just some common cases and should be refined -}
adjustDeBruijnIndex : (n : ℕ) Term Term
adjustDeBruijnIndex n (var k args) = var (k +ℕ n) args
adjustDeBruijnIndex n (def m l) = def m (map (adjustDeBruijnArg n) l)
adjustDeBruijnIndex n _ = unknown

adjustDeBruijnArg : (n : ℕ) Arg Term Arg Term
adjustDeBruijnArg n (arg i (var k args)) = arg i (var (k +ℕ n) args)
adjustDeBruijnArg n (arg i x) = arg i x

extractVarIndices : Maybe (List Term) Maybe (List ℕ)
extractVarIndices (just ((var index _) ∷ l)) with extractVarIndices (just l)
Expand All @@ -217,26 +283,27 @@ private
extractVarIndices (just []) = just []
extractVarIndices _ = nothing

getVarsAndEquation : Term Maybe (List VarInfo × Term)
listToVec : {A : Type ℓ} (l : List A) Vec A (length l)
listToVec [] = emptyVec
listToVec (x ∷ l) = x ∷vec listToVec l

getVarsAndEquation : Term List VarInfo × Term
getVarsAndEquation t =
let
(rawVars , equationTerm) = extractVars t
maybeVars = addIndices (length rawVars) rawVars
in map-Maybe (_, equationTerm) maybeVars
let (rawVars , equationTerm) = extractVars t
vars = addIndices (length rawVars) (listToVec rawVars)
in (vars , equationTerm)
where
extractVars : Term List (String × Arg Term) × Term
extractVars (pi argType (abs varName t)) with extractVars t
... | xs , equation
= (varName , argType) ∷ xs , equation
extractVars equation = [] , equation

addIndices : List (String × Arg Term) Maybe (List VarInfo)
addIndices ℕ.zero [] = just []
addIndices (ℕ.suc countVar) ((varName , argType) ∷ list) =
map-Maybe (λ varList record { varName = varName ; varType = argType ; index = countVar }
∷ varList)
(addIndices countVar list)
addIndices _ _ = nothing
addIndices : (n : ℕ) Vec (String × Arg Term) n List VarInfo
addIndices ℕ.zero emptyVec = []
addIndices (ℕ.suc countVar) ((varName , argType) ∷vec list) =
record { varName = varName ; varType = argType ; index = countVar }
∷ (addIndices countVar list)

toListOfTerms : Term Maybe (List Term)
toListOfTerms (con c []) = if (c == (quote [])) then just [] else nothing
Expand All @@ -246,64 +313,56 @@ private
toListOfTerms (con c (harg t ∷ args)) = toListOfTerms (con c args)
toListOfTerms _ = nothing

checkIsRing : Term TC Term
checkIsRing ring = checkType ring (def (quote CommRing) (varg unknown ∷ []))

solve-macro : Term Term TC Unit
solve-macro cring hole =
solve-macro uncheckedCommRing hole =
do
commRing checkIsRing uncheckedCommRing
hole′ inferType hole >>= normalise
just (varInfos , equation) returnTC (getVarsAndEquation hole′)
where
nothing
typeError (strErr "Something went wrong when getting the variable names in "
∷ termErr hole′ ∷ [])
names findRingNames commRing
(varInfos , equation) returnTC (getVarsAndEquation hole′)

{-
The call to the ring solver will be inside a lamba-expression.
That means, that we have to adjust the deBruijn-indices of the variables in 'cring'
-}
adjustedCring returnTC (adjustDeBruijnIndex (length varInfos) cring)
just (lhs , rhs) returnTC (toAlgebraExpression adjustedCring (getArgs equation))
where
nothing
typeError(
strErr "Error while trying to build ASTs for the equation "
termErr equation ∷ [])
let solution = solverCallWithLambdas (length varInfos) varInfos adjustedCring lhs rhs
adjustedCommRing returnTC (adjustDeBruijnIndex (length varInfos) commRing)
just (lhs , rhs) returnTC (toAlgebraExpression adjustedCommRing names (getArgs equation))
where nothing astExtractionError equation

let solution = solverCallWithLambdas (length varInfos) varInfos adjustedCommRing lhs rhs
unify hole solution

solveInPlace-macro : Term Term Term TC Unit
solveInPlace-macro cring varsToSolve hole =
do
equation inferType hole >>= normalise
names findRingNames cring
just varIndices returnTC (extractVarIndices (toListOfTerms varsToSolve))
where
nothing
typeError(
strErr "Error reading variables to solve "
termErr varsToSolve ∷ [])
just (lhs , rhs) returnTC (toAlgebraExpression cring (getArgs equation))
where
nothing
typeError(
strErr "Error while trying to build ASTs for the equation "
termErr equation ∷ [])
where nothing variableExtractionError varsToSolve

just (lhs , rhs) returnTC (toAlgebraExpression cring names (getArgs equation))
where nothing astExtractionError equation

let solution = solverCallByVarIndices (length varIndices) varIndices cring lhs rhs
unify hole solution

solveEqReasoning-macro : Term Term Term Term Term TC Unit
solveEqReasoning-macro lhs cring varsToSolve reasoningToTheRight hole =
do
names findRingNames cring
just varIndices returnTC (extractVarIndices (toListOfTerms varsToSolve))
where
nothing
typeError(
strErr "Error reading variables to solve "
termErr varsToSolve ∷ [])
where nothing variableExtractionError varsToSolve

just rhs returnTC (getRhs reasoningToTheRight)
where
nothing
typeError(
strErr "Failed to extract right hand side of equation to solve from "
termErr reasoningToTheRight ∷ [])
just (lhsAST , rhsAST) returnTC (toAlgebraExpression cring (just (lhs , rhs)))
just (lhsAST , rhsAST) returnTC (toAlgebraExpression cring names (just (lhs , rhs)))
where
nothing
typeError(
Expand Down

0 comments on commit 3ad5a9f

Please sign in to comment.