-
Notifications
You must be signed in to change notification settings - Fork 77
/
typing.ml
446 lines (424 loc) · 19.8 KB
/
typing.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
(* This file is part of the Catala compiler, a specification language for tax and social benefits
computation rules. Copyright (C) 2020 Inria, contributor: Denis Merigoux
<denis.merigoux@inria.fr>
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License
is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
or implied. See the License for the specific language governing permissions and limitations under
the License. *)
(** Typing for the default calculus. Because of the error terms, we perform type inference using the
classical W algorithm with union-find unification. *)
module Pos = Utils.Pos
module Errors = Utils.Errors
module A = Ast
module Cli = Utils.Cli
(** {1 Types and unification} *)
(** Unique identifiers used to tell apart the unknown types ([TAny]) created during inference. The
    identifiers carry no payload and are all printed the same way. *)
module Any =
  Utils.Uid.Make
    (struct
      type info = unit

      let format_info fmt () = Format.pp_print_string fmt "any"
    end)
    ()
(** We do not reuse {!type: Dcalc.Ast.typ} because we have to include a new [TAny] variant. Indeed,
    error terms can have any type and this has to be captured by the type system.

    Every sub-type is stored behind a union-find element so that {!unify} can merge types in
    place. *)
type typ =
  | TLit of A.typ_lit  (** Literal type, as described by the AST *)
  | TArrow of typ Pos.marked UnionFind.elem * typ Pos.marked UnionFind.elem
      (** Function type: argument type, then return type *)
  | TTuple of typ Pos.marked UnionFind.elem list  (** Product type, printed with [*] separators *)
  | TEnum of typ Pos.marked UnionFind.elem list  (** Sum type, printed with [+] separators *)
  | TArray of typ Pos.marked UnionFind.elem  (** Array type, parameterized by its cell type *)
  | TAny of Any.t  (** Unknown type, identified by a fresh [Any.t]; unifies with anything *)
(** Tells whether the type currently held by [t] has to be wrapped in parentheses when printed on
    the left of an arrow: this is the case for arrow and array types only. *)
let typ_needs_parens (t : typ Pos.marked UnionFind.elem) : bool =
  match Pos.unmark (UnionFind.get (UnionFind.find t)) with
  | TArrow _ | TArray _ -> true
  | TLit _ | TTuple _ | TEnum _ | TAny _ -> false
(** Pretty-prints on [fmt] the type held by the union-find class of [typ]. Tuples and enums are
    printed between parentheses, with [*] and [+] separators respectively; unknown types are
    printed together with the hash of their identifier. *)
let rec format_typ (fmt : Format.formatter) (typ : typ Pos.marked UnionFind.elem) : unit =
  (* Argument side of an arrow: parenthesized when [typ_needs_parens] requires it *)
  let format_arrow_arg (fmt : Format.formatter) (arg : typ Pos.marked UnionFind.elem) =
    if typ_needs_parens arg then Format.fprintf fmt "(%a)" format_typ arg else format_typ fmt arg
  in
  let product_sep fmt () = Format.fprintf fmt " *@ " in
  let sum_sep fmt () = Format.fprintf fmt " +@ " in
  (* Prints a parenthesized list of types separated by [pp_sep] *)
  let format_components pp_sep components =
    Format.fprintf fmt "(%a)" (Format.pp_print_list ~pp_sep format_typ) components
  in
  match Pos.unmark (UnionFind.get (UnionFind.find typ)) with
  | TLit lit -> Print.format_tlit fmt lit
  | TTuple components -> format_components product_sep components
  | TEnum cases -> format_components sum_sep cases
  | TArrow (arg, ret) -> Format.fprintf fmt "@[<hov 2>%a →@ %a@]" format_arrow_arg arg format_typ ret
  | TArray cell -> Format.fprintf fmt "@[%a@ array@]" format_typ cell
  | TAny id -> Format.fprintf fmt "any[%d]" (Any.hash id)
(** [unify t1 t2] merges the union-find classes of [t1] and [t2], recursively unifying their
    components. An unknown type ([TAny]) unifies with anything; when only one side is unknown, the
    merged class takes the descriptor of the other side.

    Raises an error (through [Errors.raise_multispanned_error]) if unification cannot be performed,
    that is when the two types have different head constructors, different literals, or are tuples
    or enums with a different number of components. *)
let rec unify (t1 : typ Pos.marked UnionFind.elem) (t2 : typ Pos.marked UnionFind.elem) : unit =
  (* Cli.debug_print (Format.asprintf "Unifying %a and %a" format_typ t1 format_typ t2); *)
  let t1_repr = UnionFind.get (UnionFind.find t1) in
  let t2_repr = UnionFind.get (UnionFind.find t2) in
  let repr =
    match (t1_repr, t2_repr) with
    | (TLit tl1, _), (TLit tl2, _) when tl1 = tl2 -> None
    | (TArrow (t11, t12), _), (TArrow (t21, t22), _) ->
        unify t11 t21;
        unify t12 t22;
        None
    (* The length guards matter: [List.iter2] raises [Invalid_argument] on lists of different
       lengths, which would escape as a crash instead of a typing error. When a guard fails, the
       pair falls through to the generic incompatibility case below. *)
    | (TTuple ts1, _), (TTuple ts2, _) when List.length ts1 = List.length ts2 ->
        List.iter2 unify ts1 ts2;
        None
    | (TEnum ts1, _), (TEnum ts2, _) when List.length ts1 = List.length ts2 ->
        List.iter2 unify ts1 ts2;
        None
    | (TArray t1', _), (TArray t2', _) ->
        unify t1' t2';
        None
    | (TAny _, _), (TAny _, _) -> None
    | (TAny _, _), t_repr | t_repr, (TAny _, _) -> Some t_repr
    | (_, t1_pos), (_, t2_pos) ->
        (* TODO: if we get weird error messages, then it means that we should use the persistent
           version of the union-find data structure. *)
        Errors.raise_multispanned_error
          (Format.asprintf "Error during typechecking, types %a and %a are incompatible" format_typ
             t1 format_typ t2)
          [
            (Some (Format.asprintf "Type %a coming from expression:" format_typ t1), t1_pos);
            (Some (Format.asprintf "Type %a coming from expression:" format_typ t2), t2_pos);
          ]
  in
  let t_union = UnionFind.union t1 t2 in
  (* When exactly one side was unknown, make sure the merged class keeps the known descriptor *)
  match repr with None -> () | Some t_repr -> UnionFind.set t_union t_repr
(** Operators have a single type, instead of being polymorphic with constraints. This allows us to
    have a simpler type system, while we argue the syntactic burden of operator annotations helps
    the programmer visualize the type flow in the code.

    [op_type op] builds the type of [op]; every node of the result is marked with the position of
    [op]. Operator/kind combinations that do not exist raise a spanned error. *)
let op_type (op : A.operator Pos.marked) : typ Pos.marked UnionFind.elem =
  let pos = Pos.get_position op in
  let node (t : typ) : typ Pos.marked UnionFind.elem = UnionFind.make (t, pos) in
  let bool_t = node (TLit TBool) in
  let int_t = node (TLit TInt) in
  let rat_t = node (TLit TRat) in
  let money_t = node (TLit TMoney) in
  let duration_t = node (TLit TDuration) in
  let date_t = node (TLit TDate) in
  (* Two distinct unknown types, plus arrays of the first one *)
  let alpha = node (TAny (Any.fresh ())) in
  let alpha_array = node (TArray alpha) in
  let beta = node (TAny (Any.fresh ())) in
  (* Right-associative arrow constructor: [a @-> b @-> c] reads [a -> (b -> c)] *)
  let ( @-> ) dom codom = node (TArrow (dom, codom)) in
  match Pos.unmark op with
  (* Collections *)
  | A.Ternop A.Fold -> (beta @-> alpha @-> beta) @-> beta @-> alpha_array @-> beta
  | A.Binop A.Map -> (alpha @-> beta) @-> alpha_array @-> beta
  | A.Unop A.Length -> alpha_array @-> int_t
  (* Boolean connectives *)
  | A.Binop (A.And | A.Or) -> bool_t @-> bool_t @-> bool_t
  | A.Unop A.Not -> bool_t @-> bool_t
  (* Arithmetic *)
  | A.Binop (A.Add KInt | A.Sub KInt | A.Mult KInt | A.Div KInt) -> int_t @-> int_t @-> int_t
  | A.Binop (A.Add KRat | A.Sub KRat | A.Mult KRat | A.Div KRat) -> rat_t @-> rat_t @-> rat_t
  | A.Binop (A.Add KMoney | A.Sub KMoney) -> money_t @-> money_t @-> money_t
  | A.Binop (A.Mult KMoney) -> money_t @-> rat_t @-> money_t
  | A.Binop (A.Div KMoney) -> money_t @-> money_t @-> rat_t
  | A.Binop (A.Add KDuration | A.Sub KDuration) -> duration_t @-> duration_t @-> duration_t
  | A.Binop (A.Add KDate) -> date_t @-> duration_t @-> date_t
  | A.Binop (A.Sub KDate) -> date_t @-> date_t @-> duration_t
  | A.Unop (A.Minus KInt) -> int_t @-> int_t
  | A.Unop (A.Minus KRat) -> rat_t @-> rat_t
  | A.Unop (A.Minus KMoney) -> money_t @-> money_t
  | A.Unop (A.Minus KDuration) -> duration_t @-> duration_t
  | A.Unop A.IntToRat -> int_t @-> rat_t
  (* Comparisons *)
  | A.Binop (A.Lt KInt | A.Lte KInt | A.Gt KInt | A.Gte KInt) -> int_t @-> int_t @-> bool_t
  | A.Binop (A.Lt KRat | A.Lte KRat | A.Gt KRat | A.Gte KRat) -> rat_t @-> rat_t @-> bool_t
  | A.Binop (A.Lt KMoney | A.Lte KMoney | A.Gt KMoney | A.Gte KMoney) ->
      money_t @-> money_t @-> bool_t
  | A.Binop (A.Lt KDate | A.Lte KDate | A.Gt KDate | A.Gte KDate) -> date_t @-> date_t @-> bool_t
  | A.Binop (A.Lt KDuration | A.Lte KDuration | A.Gt KDuration | A.Gte KDuration) ->
      duration_t @-> duration_t @-> bool_t
  (* Both operands share the same node, hence must have the same type *)
  | A.Binop (A.Eq | A.Neq) -> alpha @-> alpha @-> bool_t
  (* Type-preserving unary operators *)
  | A.Unop A.ErrorOnEmpty -> alpha @-> alpha
  | A.Unop (A.Log _) -> alpha @-> alpha
  (* Combinations that have no meaning *)
  | A.Binop (A.Mult (KDate | KDuration)) | A.Binop (A.Div (KDate | KDuration)) | A.Unop (A.Minus KDate)
    ->
      Errors.raise_spanned_error "This operator is not available!" pos
(** Translates a type of the AST into its inference counterpart: every sub-type is wrapped in a
    new union-find node that keeps its position mark, and each [A.TAny] becomes a fresh unknown
    type. *)
let rec ast_to_typ (ty : A.typ) : typ =
  (* Converts a marked sub-type and boxes it in a union-find node *)
  let lift sub = UnionFind.make (Pos.map_under_mark ast_to_typ sub) in
  match ty with
  | A.TAny -> TAny (Any.fresh ())
  | A.TLit lit -> TLit lit
  | A.TArray cell -> TArray (lift cell)
  | A.TTuple components -> TTuple (List.map lift components)
  | A.TEnum cases -> TEnum (List.map lift cases)
  | A.TArrow (arg, ret) -> TArrow (lift arg, lift ret)
(** Reads back an inference type as a type of the AST, following union-find representatives and
    keeping position marks. Types that are still unknown are turned into [A.TAny]. *)
let rec typ_to_ast (ty : typ Pos.marked UnionFind.elem) : A.typ Pos.marked =
  let convert = function
    | TAny _ -> A.TAny
    | TLit lit -> A.TLit lit
    | TArray cell -> A.TArray (typ_to_ast cell)
    | TArrow (arg, ret) -> A.TArrow (typ_to_ast arg, typ_to_ast ret)
    | TTuple components -> A.TTuple (List.map typ_to_ast components)
    | TEnum cases -> A.TEnum (List.map typ_to_ast cases)
  in
  ty |> UnionFind.find |> UnionFind.get |> Pos.map_under_mark convert
(** {1 Double-directed typing} *)
(** Typing environment: associates each variable in scope with the union-find node holding its
    type. *)
type env = typ Pos.marked UnionFind.elem A.VarMap.t
(** Infers the most permissive type from an expression.

    [env] gives the type of the variables in scope. The returned union-find node may be shared with
    sub-expressions or with [env], and types are unified in place along the way. A spanned error is
    raised when a variable is unbound, when a tuple access or an injection uses an out-of-range
    index, when a tuple access is performed on something that is not known to be a tuple, or when a
    function binds a number of variables different from the number of types it is annotated
    with. *)
let rec typecheck_expr_bottom_up (env : env) (e : A.expr Pos.marked) : typ Pos.marked UnionFind.elem
    =
  let out =
    match Pos.unmark e with
    | EVar v -> (
        match A.VarMap.find_opt (Pos.unmark v) env with
        | Some t -> t
        | None ->
            Errors.raise_spanned_error "Variable not found in the current context"
              (Pos.get_position e) )
    | ELit (LBool _) -> UnionFind.make (Pos.same_pos_as (TLit TBool) e)
    | ELit (LInt _) -> UnionFind.make (Pos.same_pos_as (TLit TInt) e)
    | ELit (LRat _) -> UnionFind.make (Pos.same_pos_as (TLit TRat) e)
    | ELit (LMoney _) -> UnionFind.make (Pos.same_pos_as (TLit TMoney) e)
    | ELit (LDate _) -> UnionFind.make (Pos.same_pos_as (TLit TDate) e)
    | ELit (LDuration _) -> UnionFind.make (Pos.same_pos_as (TLit TDuration) e)
    | ELit LUnit -> UnionFind.make (Pos.same_pos_as (TLit TUnit) e)
    (* The empty error can take any type *)
    | ELit LEmptyError -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e)
    | ETuple es ->
        let ts = List.map (fun (e, _) -> typecheck_expr_bottom_up env e) es in
        UnionFind.make (Pos.same_pos_as (TTuple ts) e)
    | ETupleAccess (e1, n, _) -> (
        let t1 = typecheck_expr_bottom_up env e1 in
        match Pos.unmark (UnionFind.get (UnionFind.find t1)) with
        | TTuple ts -> (
            match List.nth_opt ts n with
            | Some t' -> t'
            | None ->
                (* [n] is a 0-based index, so the tuple needs at least [n + 1] elements *)
                Errors.raise_spanned_error
                  (Format.asprintf
                     "Expression should have a tuple type with at least %d elements but only has %d"
                     (n + 1) (List.length ts))
                  (Pos.get_position e1) )
        | _ ->
            Errors.raise_spanned_error
              (Format.asprintf "Expected a tuple, got a %a" format_typ t1)
              (Pos.get_position e1) )
    | EInj (e1, n, _, ts) ->
        let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in
        let ts_n =
          match List.nth_opt ts n with
          | Some ts_n -> ts_n
          | None ->
              (* [n] is a 0-based index, so the sum type needs at least [n + 1] cases *)
              Errors.raise_spanned_error
                (Format.asprintf
                   "Expression should have a sum type with at least %d cases but only has %d"
                   (n + 1) (List.length ts))
                (Pos.get_position e)
        in
        typecheck_expr_top_down env e1 ts_n;
        UnionFind.make (Pos.same_pos_as (TEnum ts) e)
    | EMatch (e1, es) ->
        (* One unknown payload type per arm; the scrutinee must be an enum of these payloads *)
        let enum_cases =
          List.map (fun (e', _) -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) es
        in
        let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum enum_cases) e1) in
        typecheck_expr_top_down env e1 t_e1;
        let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
        (* Each arm is a function from its payload to the common return type. [enum_cases] was
           built from [es], so both lists have the same length and [List.iter2] cannot fail. *)
        List.iter2
          (fun enum_t (es', _) ->
            let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in
            typecheck_expr_top_down env es' t_es')
          enum_cases es;
        t_ret
    | EAbs (pos_binder, binder, taus) ->
        let xs, body = Bindlib.unmbind binder in
        if Array.length xs = List.length taus then
          let xstaus =
            List.map2
              (fun x tau -> (x, UnionFind.make (ast_to_typ (Pos.unmark tau), Pos.get_position tau)))
              (Array.to_list xs) taus
          in
          let env = List.fold_left (fun env (x, tau) -> A.VarMap.add x tau env) env xstaus in
          (* Curried function type: arg1 -> (arg2 -> ... -> body type) *)
          List.fold_right
            (fun (_, t_arg) (acc : typ Pos.marked UnionFind.elem) ->
              UnionFind.make (TArrow (t_arg, acc), pos_binder))
            xstaus
            (typecheck_expr_bottom_up env body)
        else
          Errors.raise_spanned_error
            (Format.asprintf "function has %d variables but was supplied %d types" (Array.length xs)
               (List.length taus))
            pos_binder
    | EApp (e1, args) ->
        let t_args = List.map (typecheck_expr_bottom_up env) args in
        let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
        (* The applied expression must be a curried function from the arguments to [t_ret] *)
        let t_app =
          List.fold_right
            (fun t_arg acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
            t_args t_ret
        in
        typecheck_expr_top_down env e1 t_app;
        t_ret
    | EOp op -> op_type (Pos.same_pos_as op e)
    | EDefault (excepts, just, cons) ->
        typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as (TLit TBool) just));
        let tcons = typecheck_expr_bottom_up env cons in
        List.iter (fun except -> typecheck_expr_top_down env except tcons) excepts;
        tcons
    | EIfThenElse (cond, et, ef) ->
        typecheck_expr_top_down env cond (UnionFind.make (Pos.same_pos_as (TLit TBool) cond));
        let tt = typecheck_expr_bottom_up env et in
        typecheck_expr_top_down env ef tt;
        tt
    | EAssert e' ->
        typecheck_expr_top_down env e' (UnionFind.make (Pos.same_pos_as (TLit TBool) e'));
        UnionFind.make (Pos.same_pos_as (TLit TUnit) e')
    | EArray es ->
        (* All the cells are unified with a single, initially unknown, cell type *)
        let cell_type = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
        List.iter
          (fun e' ->
            let t_e' = typecheck_expr_bottom_up env e' in
            unify cell_type t_e')
          es;
        UnionFind.make (Pos.same_pos_as (TArray cell_type) e)
  in
  (* Cli.debug_print (Format.asprintf "Found type of %a: %a" Print.format_expr e format_typ out); *)
  out
(** Checks whether the expression can be typed with the provided type.

    [tau] is unified in place with the type of [e], so it gets refined as a side effect. A spanned
    error is raised when a variable is unbound, when a tuple does not have the expected number of
    elements, when a tuple access or an injection uses an out-of-range index, when a tuple access is
    performed on an expression whose type is not known to be a tuple, or when a function binds a
    number of variables different from the number of types it is annotated with. Type mismatches
    are reported by {!unify}. *)
and typecheck_expr_top_down (env : env) (e : A.expr Pos.marked)
    (tau : typ Pos.marked UnionFind.elem) : unit =
  (* Cli.debug_print (Format.asprintf "Typechecking %a : %a" Print.format_expr e format_typ tau); *)
  match Pos.unmark e with
  | EVar v -> (
      match A.VarMap.find_opt (Pos.unmark v) env with
      | Some tau' -> unify tau tau'
      | None ->
          Errors.raise_spanned_error "Variable not found in the current context"
            (Pos.get_position e) )
  | ELit (LBool _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TBool) e))
  | ELit (LInt _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TInt) e))
  | ELit (LRat _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TRat) e))
  | ELit (LMoney _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TMoney) e))
  | ELit (LDate _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TDate) e))
  | ELit (LDuration _) -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TDuration) e))
  | ELit LUnit -> unify tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e))
  | ELit LEmptyError -> unify tau (UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e))
  | ETuple es -> (
      let tau' = UnionFind.get (UnionFind.find tau) in
      match Pos.unmark tau' with
      | TTuple ts ->
          (* [List.iter2] raises [Invalid_argument] on lists of different lengths: report an
             arity mismatch as a regular typing error instead *)
          if List.length es = List.length ts then
            List.iter2 (fun (e, _) t -> typecheck_expr_top_down env e t) es ts
          else
            Errors.raise_spanned_error
              (Format.asprintf "expected %a, got a tuple with %d elements" format_typ tau
                 (List.length es))
              (Pos.get_position e)
      | TAny _ ->
          (* Nothing is known about the expected type yet: infer the tuple type bottom-up *)
          unify tau
            (UnionFind.make
               (Pos.same_pos_as
                  (TTuple (List.map (fun (arg, _) -> typecheck_expr_bottom_up env arg) es))
                  e))
      | _ ->
          Errors.raise_spanned_error
            (Format.asprintf "expected %a, got a tuple" format_typ tau)
            (Pos.get_position e) )
  | ETupleAccess (e1, n, _) -> (
      let t1 = typecheck_expr_bottom_up env e1 in
      match Pos.unmark (UnionFind.get (UnionFind.find t1)) with
      | TTuple t1s -> (
          match List.nth_opt t1s n with
          | Some t1n -> unify t1n tau
          | None ->
              (* [n] is a 0-based index, so the tuple needs at least [n + 1] elements *)
              Errors.raise_spanned_error
                (Format.asprintf
                   "expression should have a tuple type with at least %d elements but only has %d"
                   (n + 1) (List.length t1s))
                (Pos.get_position e1) )
      | TAny _ ->
          (* Include total number of cases in ETupleAccess to continue typechecking at this point *)
          Errors.raise_spanned_error
            "The precise type of this expression cannot be inferred.\n\
             Please raise an issue on https://github.com/CatalaLang/catala/issues"
            (Pos.get_position e1)
      | _ ->
          (* The offending type is the one of the accessed expression [e1], not the type [tau]
             expected for the result of the access *)
          Errors.raise_spanned_error
            (Format.asprintf "expected a tuple, got %a" format_typ t1)
            (Pos.get_position e1) )
  | EInj (e1, n, _, ts) ->
      let ts = List.map (fun t -> UnionFind.make (Pos.map_under_mark ast_to_typ t)) ts in
      let ts_n =
        match List.nth_opt ts n with
        | Some ts_n -> ts_n
        | None ->
            (* [n] is a 0-based index, so the sum type needs at least [n + 1] cases *)
            Errors.raise_spanned_error
              (Format.asprintf
                 "Expression should have a sum type with at least %d cases but only has %d" (n + 1)
                 (List.length ts))
              (Pos.get_position e)
      in
      typecheck_expr_top_down env e1 ts_n;
      unify (UnionFind.make (Pos.same_pos_as (TEnum ts) e)) tau
  | EMatch (e1, es) ->
      (* One unknown payload type per arm; the scrutinee must be an enum of these payloads *)
      let enum_cases =
        List.map (fun (e', _) -> UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e')) es
      in
      let t_e1 = UnionFind.make (Pos.same_pos_as (TEnum enum_cases) e1) in
      typecheck_expr_top_down env e1 t_e1;
      let t_ret = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
      (* Each arm is a function from its payload to the common return type. [enum_cases] was built
         from [es], so both lists have the same length and [List.iter2] cannot fail. *)
      List.iter2
        (fun enum_t (es', _) ->
          let t_es' = UnionFind.make (Pos.same_pos_as (TArrow (enum_t, t_ret)) es') in
          typecheck_expr_top_down env es' t_es')
        enum_cases es;
      unify tau t_ret
  | EAbs (pos_binder, binder, t_args) ->
      let xs, body = Bindlib.unmbind binder in
      if Array.length xs = List.length t_args then
        let xstaus =
          List.map2
            (fun x t_arg -> (x, UnionFind.make (Pos.map_under_mark ast_to_typ t_arg)))
            (Array.to_list xs) t_args
        in
        let env = List.fold_left (fun env (x, t_arg) -> A.VarMap.add x t_arg env) env xstaus in
        let t_out = typecheck_expr_bottom_up env body in
        (* Curried function type: arg1 -> (arg2 -> ... -> body type) *)
        let t_func =
          List.fold_right
            (fun (_, t_arg) acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
            xstaus t_out
        in
        unify t_func tau
      else
        Errors.raise_spanned_error
          (Format.asprintf "function has %d variables but was supplied %d types" (Array.length xs)
             (List.length t_args))
          pos_binder
  | EApp (e1, args) ->
      let t_args = List.map (typecheck_expr_bottom_up env) args in
      let te1 = typecheck_expr_bottom_up env e1 in
      (* The applied expression must be a curried function from the arguments to [tau] *)
      let t_func =
        List.fold_right
          (fun t_arg acc -> UnionFind.make (Pos.same_pos_as (TArrow (t_arg, acc)) e))
          t_args tau
      in
      unify te1 t_func
  | EOp op ->
      let op_typ = op_type (Pos.same_pos_as op e) in
      unify op_typ tau
  | EDefault (excepts, just, cons) ->
      typecheck_expr_top_down env just (UnionFind.make (Pos.same_pos_as (TLit TBool) just));
      typecheck_expr_top_down env cons tau;
      List.iter (fun except -> typecheck_expr_top_down env except tau) excepts
  | EIfThenElse (cond, et, ef) ->
      typecheck_expr_top_down env cond (UnionFind.make (Pos.same_pos_as (TLit TBool) cond));
      typecheck_expr_top_down env et tau;
      typecheck_expr_top_down env ef tau
  | EAssert e' ->
      typecheck_expr_top_down env e' (UnionFind.make (Pos.same_pos_as (TLit TBool) e'));
      unify tau (UnionFind.make (Pos.same_pos_as (TLit TUnit) e'))
  | EArray es ->
      (* All the cells are unified with a single, initially unknown, cell type *)
      let cell_type = UnionFind.make (Pos.same_pos_as (TAny (Any.fresh ())) e) in
      List.iter
        (fun e' ->
          let t_e' = typecheck_expr_bottom_up env e' in
          unify cell_type t_e')
        es;
      unify tau (UnionFind.make (Pos.same_pos_as (TArray cell_type) e))
(** {1 API} *)
(** Infers the type of an expression, starting from an empty typing environment, and returns it as
    a type of the AST. *)
let infer_type (e : A.expr Pos.marked) : A.typ Pos.marked =
  typ_to_ast (typecheck_expr_bottom_up A.VarMap.empty e)
(** Typechecks an expression against the expected type [tau], starting from an empty typing
    environment. Errors are reported by the underlying top-down checker. *)
let check_type (e : A.expr Pos.marked) (tau : A.typ Pos.marked) =
  let expected = UnionFind.make (Pos.map_under_mark ast_to_typ tau) in
  typecheck_expr_top_down A.VarMap.empty e expected