Browse files

[fix] stdlib, types: fix the wip unification from rudy (especialy row…

…s and cols)
  • Loading branch information...
1 parent d8cacd5 commit eedbe5441501b17a523194358156eb51331e7d1e @BourgerieQuentin BourgerieQuentin committed Feb 17, 2012
View
2 stdlib/unification/cycle_detection.opa → stdlib/core/unification/cycle_detection.opa
@@ -1,5 +1,5 @@
/*
- Copyright © 2011 MLstate
+ Copyright © 2011, 2012 MLstate
This file is part of OPA.
View
267 stdlib/unification/opatype_unification.opa → .../core/unification/opatype_unification.opa
@@ -1,5 +1,5 @@
/*
- Copyright © 2011 MLstate
+ Copyright © 2011, 2012 MLstate
This file is part of OPA.
@@ -16,6 +16,8 @@
along with OPA. If not, see <http://www.gnu.org/licenses/>.
*/
+import stdlib.core.{set, map, compare}
+
/**
* Unification of OPA runtime types
*/
@@ -25,32 +27,31 @@
*
*/
-import stdlib.core.compare
-type OpaType.Unification.push_elmt = (OpaType.ty,OpaType.ty)
-type OpaType.Unification.cycle_detector = TortoiseAndHare.t(OpaType.Unification.push_elmt)
-type OpaType.Unification.CycleDetector = CycleDetector(OpaType.Unification.cycle_detector,OpaType.Unification.push_elmt)
+type OpaTypeUnification.push_elmt = (OpaType.ty,OpaType.ty)
+type OpaTypeUnification.cycle_detector = TortoiseAndHare.t(OpaTypeUnification.push_elmt)
+type OpaTypeUnification.CycleDetector = CycleDetector(OpaTypeUnification.cycle_detector,OpaTypeUnification.push_elmt)
-type OpaType.Unification.subst = { var : stringmap(OpaType.ty)
- col : stringmap(OpaType.ty)
- row : stringmap(OpaType.ty)
- cycle_detector : OpaType.Unification.cycle_detector}
+type OpaTypeUnification.subst = { var : stringmap(OpaType.ty)
+ col : stringmap(list(OpaType.fields))
+ row : stringmap(OpaType.fields)
+ cycle_detector : OpaTypeUnification.cycle_detector}
-type OpaType.Unification.error_case = {generic : (OpaType.ty,OpaType.ty)} / {incompatible_arity} / {incompatible_quantification} / {incompatible_record}
+type OpaTypeUnification.error_case = {generic : (OpaType.ty,OpaType.ty)} / {incompatible_arity} / {incompatible_quantification} / {incompatible_record}
-type OpaType.Unification.error = list(OpaType.Unification.error_case)
+type OpaTypeUnification.error = list(OpaTypeUnification.error_case)
-type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unification.error)
+type OpaTypeUnification.result = outcome(OpaTypeUnification.subst,OpaTypeUnification.error)
/*
* This module contains private routines relative to unification
*/
@private P = {{
cmp_ty_ty(tys0,tys1) = Compare.equal_ty(tys0.f1,tys1.f1) && Compare.equal_ty(tys0.f2,tys1.f2)
- CycleDetector = TortoiseAndHare.create(cmp_ty_ty):OpaType.Unification.CycleDetector
+ CycleDetector = TortoiseAndHare.create(cmp_ty_ty):OpaTypeUnification.CycleDetector
- empty_subst = {var=StringMap.empty col=StringMap.empty row=StringMap.empty cycle_detector=CycleDetector.empty} : OpaType.Unification.subst
+ empty_subst = {var=StringMap.empty col=StringMap.empty row=StringMap.empty cycle_detector=CycleDetector.empty} : OpaTypeUnification.subst
/* substitute_var_no_opti(tv,ty,subst) get the global substitution type for tv if any or return ty (e.g. TyVar tv normally) */
substitute_var_no_opti(tv,ty,subst):OpaType.ty =
@@ -77,8 +78,8 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
unify_var_ty(tv,ty,subst) = {subst with var=StringMap.add(tv,ty,subst.var)}
/* and on a outcome(substitunion) and a substituion transformaion function (subst->outcome(subst)) only applied if the former is {success=...} */
- `&&&`(out_subst,unif):OpaType.Unification.result=
- match out_subst:OpaType.Unification.result
+ `&&&`(out_subst,unif):OpaTypeUnification.result=
+ match out_subst:OpaTypeUnification.result
{failure=_} as f-> f
{success=s} -> unif(s)
@@ -102,20 +103,20 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
check_no_substitute(ty,subst) = substitute_no_opti(ty,subst)===ty
/* unify_vars(v1,v2,t1,t2,subst) add a substitiution between v1 (<=>t1) and v2 (<=>t2), the substitution is */
- unify_vars(v1,v2,t1,t2,subst):OpaType.Unification.subst =
+ unify_vars(v1,v2,t1,t2,subst):OpaTypeUnification.subst =
if v1==v2 then subst
else if v1 < v2 then unify_var_ty(v2,t1,subst)
else unify_var_ty(v1,t2,subst)
- unifiable_list(diff_len_error, l1, l2, subst)(unifiable):OpaType.Unification.result =
+ unifiable_list(diff_len_error, l1, l2, subst)(unifiable):OpaTypeUnification.result =
if List.length(l1) != List.length(l2) then {failure=diff_len_error}
- else Fold.list2(l1,l2,{success=subst})(v1,v2,subst->
- match subst
- {failure=_} -> subst
- {success=subst} -> unifiable(v1,v2,subst)
- )
+ else List.fold2((v1,v2,subst->
+ match subst
+ {failure=_} -> subst
+ {success=subst} -> unifiable(v1,v2,subst)
+ ), l1,l2,{success=subst})
- unifiable_function_parameters(l1, l2,subst) =
+ unifiable_function_parameters(l1, l2,subst) =
unifiable_list([{incompatible_arity}],l1,l2,subst)(unifiable)
unifiable_named_parameters(l1, l2,subst) =
@@ -145,13 +146,11 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
tyl2 = full_expansion(ty2,n2,args2)
depth1 = List.length(tyl1)
depth2 = List.length(tyl2)
-// do println("EXPANSION\n{tyl1}\n{tyl2}\n")
tyl1 = List.drop(max(0,depth1-depth2), List.rev(tyl1))
tyl2 = List.drop(max(0,depth2-depth1), List.rev(tyl2))
-// do println("SYNC EXPANSION\n{tyl1}\n{tyl2}\n")
(tyl1,tyl2)
- unifiable_named(ty1,n1,args1,ty2,n2,args2,subst):OpaType.Unification.result =
+ unifiable_named(ty1,n1,args1,ty2,n2,args2,subst):OpaTypeUnification.result =
if n1==n2 then unifiable_named_parameters(args1,args2,subst)
else unifiable_sync_expansion(expand_and_sync(ty1,n1,args1,ty2,n2,args2),subst)
@@ -162,47 +161,66 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
unifiable_fields(f1:OpaType.fields, f2:OpaType.fields,subst) =
unifiable_list([{incompatible_record}],f1,f2,subst)(unifiable_field)
- unifiable_row(r1,r2,subst) =
- /* assumed sorted */
- unifiable_fields(r1, r2,subst)
+ // must be called each time a structural expansion is used for being unified
+ // if a unification requires itself than it is assumed safe
+ unifiable_with_cycle_cut(ty1,ty2,subst)=
+ old_cycle_detector = subst.cycle_detector
+ cycle_detector = CycleDetector.push((ty1,ty2),old_cycle_detector)
+ if cycle_detector.detected then {success=subst}
+ else match unifiable(ty1,ty2,{subst with ~cycle_detector})
+ {success=subst} -> {success={subst with cycle_detector=old_cycle_detector}}
+ r -> r
+ /* ************************************************************** */
+ /* GENERIC FOR LIST ********************************************* */
unifiable_list_with_vars(
cmp:'elmt,'elmt->'cmp,
- unifiable_var:'var,'elmt,'subst-> OpaType.Unification.result,
- unifiable_elmt:'elmt,'elmt,'subst -> OpaType.Unification.result,
+ unifiable_var:option('var),option('var),list('elmt),list('elmt),'subst -> OpaTypeUnification.result,
+ unifiable_elmt:'elmt,'elmt,'subst -> OpaTypeUnification.result,
v1:option('var),v2:option('var),l1:list('elmt),l2:list('elmt),subst:'subst
- ):OpaType.Unification.result =
- continue(l1,l2,subst) = unifiable_list_with_vars(cmp,unifiable_var,unifiable_elmt,v1,v2,l1,l2,subst)
- raw_unifiable_var_and_continue(v,e,nl1,nl2) =
- match v
- {some=v} -> unifiable_var(v,e,subst) &&& continue(nl1,nl2,_)
- {none} -> {failure=[]}
- end
-
- unifiable_var_and_continue(side,v1,v2,e1,e2,l1,l2,nl1,nl2) =
- if side then raw_unifiable_var_and_continue(v2,e1,nl1,l2)
- else raw_unifiable_var_and_continue(v1,e2,l1,nl2)
-
- match (l1,l2)
- ([e1|nl1],[e2|nl2]) ->
- match cmp(e1,e2)
- {eq} -> unifiable_elmt(e1,e2,subst) &&& continue(nl1,nl2,_)
- _ as neq -> unifiable_var_and_continue(neq=={lt},v1,v2,e1,e2,l1,l2,nl1,nl2)
- end
- ([],[]) -> {success=subst}
- ([e|nl1],[] as nl2)
- ([] as nl1,[e|nl2]) -> unifiable_var_and_continue(l2==[],v1,v2,e,e,l1,l2,nl1,nl2)
- end
-
-
- compare_field(f1:OpaType.field,f2:OpaType.field) = String.compare(f1.label,f2.label)
-
- unifiable_rowvar(_v,_f,_subst):OpaType.Unification.result =
- do @assert(false) // do something with v and f
- @fail("unifiable_rowvar")
-
- unifiable_row_with_vars(r1:OpaType.fields,r2:OpaType.fields,v1:option(OpaType.rowvar),v2:option(OpaType.rowvar),subst) =
- unifiable_list_with_vars(compare_field,unifiable_rowvar,unifiable_field,v1,v2,r1,r2,subst)
+ ):OpaTypeUnification.result =
+ rec append(l, a) =
+ match a with
+ | [] -> l
+ | [ta|qa] -> append(ta+>l, qa)
+ rec aux(l1, l2, a1, a2, subst:OpaTypeUnification.subst) =
+ match (l1,l2)
+ |([e1|nl1],[e2|nl2]) ->
+ match cmp(e1,e2)
+ | {eq} -> unifiable_elmt(e1, e2, subst) &&& aux(nl1, nl2, a1, a2,_)
+ | {lt} -> aux(nl1, l2, e1 +> a1, a2, subst)
+ | {gt} -> aux(l1, nl2, a1, e2 +> a2, subst)
+ end
+ |([],[]) ->
+ match (a1, a2) with
+ |([], []) -> {success=subst}
+ | _ -> unifiable_var(v1, v2, List.rev(a1), List.rev(a2), subst)
+ end
+ |(_, _) -> unifiable_var(v1, v2, append(l1, a1), append(l2, a2), subst)
+ end
+ aux(l1, l2, [], [], subst)
+
+ /* ************************************************************** */
+ /* COLVAR ******************************************************* */
+ substitute_colvar(cv, col:list(OpaType.fields), subst) =
+ match StringMap.get(cv, subst.col) with
+ | {none} -> {success = {subst with col = StringMap.add(cv, col, subst.col)}}
+ | {some = col2} -> unifiable_cols(col, col2, subst)
+
+ unifiable_colvar(c1, c2, l1:list(OpaType.fields), l2:list(OpaType.fields), subst) =
+ match (c1, c2, l1, l2) with
+ | ({none} , {none} , [] , []) -> {success = subst}
+ | ({some=v1}, {some=v2}, [], []) ->
+ if v1 == v2 then {success = subst}
+ else substitute_colvar(v1, [], subst) &&& substitute_colvar(v2, [], _)
+
+ | ({some=v}, {none}, [], l) -> substitute_colvar(v, l, subst)
+ | ({none}, {some=v}, l, []) -> substitute_colvar(v, l, subst)
+
+ | ({some=_v1}, {some=_v2}, _, _)
+ | ({none}, _, _, [_|_])
+ | (_, {none}, [_|_], _)
+ -> {failure = [{incompatible_record}]}
compare_fields(l1:OpaType.fields,l2:OpaType.fields) =
match l1
@@ -219,27 +237,47 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
end
end
- unifiable_colvar(_v,_r,_subst) =
- do @assert(false) // do something with v and r
- @fail("unifiable_colvar")
+ unifiable_cols_with_vars(rl1,rl2,v1,v2,subst) =
+ unifiable_list_with_vars(compare_fields,unifiable_colvar,unifiable_fields,v1,v2,rl1,rl2,subst)
- unifiable_sums_with_vars(rl1,rl2,v1,v2,subst) = unifiable_list_with_vars(compare_fields,unifiable_colvar,unifiable_fields,v1,v2,rl1,rl2,subst)
+ /* TODO - without vars */
+ unifiable_cols(c1, c2, subst) =
+ unifiable_cols_with_vars(c1, c2, none, none, subst)
- // must be called each time a structural expansion is used for being unified
- // if a unification requires itself than it is assumed safe
- unifiable_with_cycle_cut(ty1,ty2,subst)=
- old_cycle_detector = subst.cycle_detector
- cycle_detector = CycleDetector.push((ty1,ty2),old_cycle_detector)
- if cycle_detector.detected then
- do println("CUT {OpaType.to_pretty(ty1)} vs {OpaType.to_pretty(ty2)}")
- {success=subst}
- else match unifiable(ty1,ty2,{subst with ~cycle_detector})
- {success=subst} -> {success={subst with cycle_detector=old_cycle_detector}}
- r -> r
+ /* ************************************************************** */
+ /* ROWVAR ******************************************************* */
+ substitute_rowvar(rv, row, subst) =
+ match StringMap.get(rv, subst.row) with
+ | {none} -> {success = {subst with row = StringMap.add(rv, row, subst.row)}}
+ | {some = row2} -> unifiable_row(row, row2, subst)
+
+ unifiable_rowvar(v1, v2, l1, l2, subst):OpaTypeUnification.result =
+ match (v1, v2, l1, l2) with
+ | ({none} , {none} , [] , []) -> {success = subst}
+ | ({some=v1}, {some=v2}, [], []) ->
+ if v1 == v2 then {success = subst}
+ else substitute_rowvar(v1, [], subst) &&& substitute_rowvar(v2, [], _)
+
+ | ({some=_v1}, {some=_v2}, _, _)
+ | ({none}, _, [_|_], _)
+ | (_, {none}, _, [_|_])
+ -> {failure = [{incompatible_record}]}
- unifiable(ty1,ty2,subst):OpaType.Unification.result =
- do println("UNIFIABLE {OpaType.to_pretty(ty1)} vs {OpaType.to_pretty(ty2)} | {StringMap.size(subst.var)} vars")
-// if ty1===ty2 then {success=subst} else
+ | ({some=v}, {none}, l, []) -> substitute_rowvar(v, l, subst)
+ | ({none}, {some=v}, [], l) -> substitute_rowvar(v, l, subst)
+
+ compare_field(f1:OpaType.field,f2:OpaType.field) = String.ordering(f1.label,f2.label)
+
+ unifiable_row_with_vars(r1:OpaType.fields,r2:OpaType.fields,v1:option(OpaType.rowvar),v2:option(OpaType.rowvar),subst) =
+ unifiable_list_with_vars(compare_field,unifiable_rowvar,unifiable_field,v1,v2,r1,r2,subst)
+
+ unifiable_row(r1,r2,subst) =
+ unifiable_fields(r1, r2,subst)
+
+
+ /* ************************************************************** */
+ /* TYPES ******************************************************** */
+ unifiable(ty1,ty2,subst):OpaTypeUnification.result =
r = match (ty1,ty2) with
/* Named type */
({TyName_ident=n1; TyName_args=args1},{TyName_ident=n2; TyName_args=args2}) ->
@@ -260,8 +298,8 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
({TyVar=v1}, {TyVar=v2}) ->
// first call, something may not be completely substituted
if v1 == v2 then {success=subst} else // terminal case
- (nty1,subst) = substitute_var(v1,ty1,subst) // substitute and optimize substitution
- (nty2,subst) = substitute_var(v2,ty2,subst) // substitute and optimize substitution
+ (nty1,subst) = substitute_var(v1,ty2,subst) // substitute and optimize substitution
+ (nty2,subst) = substitute_var(v2,ty1,subst) // substitute and optimize substitution
if nty1===ty1 && nty2===ty2 then
// unify var on core tvar basis (non substituable tvar)
{success=unify_vars(v1,v2,ty1,ty2,subst)}
@@ -270,11 +308,11 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
unifiable(nty1,nty2,subst)
({TyVar=v1}, _) ->
- (nty1,subst) = substitute_var(v1,ty1,subst) // completely substituted
+ (nty1,subst) = substitute_var(v1,ty2,subst) // completely substituted
unifiable(nty1,ty2,subst)
(_, {TyVar=v2}) ->
- (nty2,subst) = substitute_var(v2,ty2,subst) // completely substituted
+ (nty2,subst) = substitute_var(v2,ty1,subst) // completely substituted
unifiable(ty1,nty2,subst)
@@ -285,12 +323,20 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
({TyRecord_row=r1 TyRecord_rowvar=v1}, {TyRecord_row=r2}) -> unifiable_row_with_vars(r1,r2,some(v1),none,subst)
({TyRecord_row=r1}, {TyRecord_row=r2 TyRecord_rowvar=v2}) -> unifiable_row_with_vars(r1,r2,none,some(v2),subst)
- /* Sums */
- ({TySum_col=l1}, {TySum_col=l2}) -> unifiable_sums_with_vars(l1,l2,none,none,subst) // TODO version without vars
+ /* Cols */
+ ({TySum_col=l1}, {TySum_col=l2}) -> unifiable_cols(l1,l2,subst)
/* ColVar */
- ({TySum_col=l1 TySum_colvar=c1}, {TySum_col=l2; TySum_colvar=c2}) -> unifiable_sums_with_vars(l1,l2,some(c1),some(c2),subst)
- ({TySum_col=l1 TySum_colvar=c1}, {TySum_col=l2}) -> unifiable_sums_with_vars(l1,l2,some(c1),none,subst)
- ({TySum_col=l1}, {TySum_col=l2 TySum_colvar=c2}) -> unifiable_sums_with_vars(l1,l2,none,some(c2),subst)
+ ({TySum_col=l1 TySum_colvar=c1}, {TySum_col=l2; TySum_colvar=c2}) -> unifiable_cols_with_vars(l1,l2,some(c1),some(c2),subst)
+ ({TySum_col=l1 TySum_colvar=c1}, {TySum_col=l2}) -> unifiable_cols_with_vars(l1,l2,some(c1),none,subst)
+ ({TySum_col=l1}, {TySum_col=l2 TySum_colvar=c2}) -> unifiable_cols_with_vars(l1,l2,none,some(c2),subst)
+ /* Rows vs Cols*/
+ ({TyRecord_row=row}, {TySum_col=col TySum_colvar=cv})
+ ({TySum_col=col TySum_colvar=cv}, {TyRecord_row=row})
+ -> unifiable_cols_with_vars([row], col, none, some(cv), subst)
+
+ ({TyRecord_row=row}, {TySum_col=col})
+ ({TySum_col=col}, {TyRecord_row=row})
+ -> unifiable_cols([row], col, subst)
/* For alls */
({TyForall_quant=q1 TyForall_body=b1}, {TyForall_quant=q2 TyForall_body=b2}) ->
@@ -303,6 +349,7 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
if bijection(iq1.types,iq2.types,proj_var,substitute_var_to_var(_,subst))
//&& bijection(iq1.rows,iq2.rows,substitute_row_to_row)
//&& bijection(iq1.cols,iq2.cols,substitute_col_to_col)
+ // TODO
then {success=subst}
else {failure=[{incompatible_quantification}]}
f -> f
@@ -326,7 +373,7 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
({TyRecord_row=_ TyRecord_rowvar=_},_)
({TySum_col=_},_)
({TySum_col=_ TySum_colvar=_},_)
- -> error("")
+ -> {failure=[]}
end
match r
@@ -345,25 +392,22 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
|{TyVar=v} -> v
|_ -> @fail("proj_var")
- @private substitute_col_to_col(_v1,_subst) = @fail("substitute_col_to_col") // TODO after unification
- @private substitute_row_to_row(_v1,_subst) = @fail("substitute_row_to_row") // TODO after unification
-
// we assume no subst exists between vars of l1 or vars of l2
bijection(l1,l2,proj,get_var)=
- up(l)(acc) = Fold.list(l,acc)(v1,(set1,set2)->
- match get_var(v1):option
- {some=v2} ->
- if StringSet.mem(v1,set1) || StringSet.mem(v1,set2)
- && StringSet.mem(v2,set1) || StringSet.mem(v2,set2)
- then
- set1 = StringSet.remove(v1,set1)
- set2 = StringSet.remove(v1,set2)
- set1 = StringSet.remove(v2,set1)
- set2 = StringSet.remove(v2,set2)
- (set1,set2)
- else (set1,set2)
- _ -> (set1,set2)
- )
+ up(l)(acc) = List.fold((v1,(set1,set2)->
+ match get_var(v1):option
+ {some=v2} ->
+ if StringSet.mem(v1,set1) || StringSet.mem(v1,set2)
+ && StringSet.mem(v2,set1) || StringSet.mem(v2,set2)
+ then
+ set1 = StringSet.remove(v1,set1)
+ set2 = StringSet.remove(v1,set2)
+ set1 = StringSet.remove(v2,set1)
+ set2 = StringSet.remove(v2,set2)
+ (set1,set2)
+ else (set1,set2)
+ _ -> (set1,set2)
+ ),l,acc)
l1 = List.map(proj,l1)
l2 = List.map(proj,l2)
set1 = StringSet.From.list(l1)
@@ -373,6 +417,7 @@ type OpaType.Unification.result = outcome(OpaType.Unification.subst,OpaType.Unif
}}
-Unification = {{
- unifiable(ty1,ty2) = P.unifiable(ty1,ty2,P.empty_subst)
+OpaTypeUnification = {{
+ unify(ty1, ty2) = P.unifiable(ty1, ty2, P.empty_subst)
+ is_unifiable(ty1, ty2) = Outcome.is_success(unify(ty1, ty2))
}}

0 comments on commit eedbe54

Please sign in to comment.