Skip to content

Commit

Permalink
Infer callee first in function call
Browse files Browse the repository at this point in the history
  The current inferrence system walks expressions from "top to bottom".
  Starting from definitions higher in the source file, and down. When a
  call is encountered, we use the information known for the callee
  definition we have at the moment it is inferred.

  This causes interesting issues in the case where the callee doesn't
  have annotations and in only partially known. For example:

  ```
  pub fn list(fuzzer: Option<a>) -> Option<List<a>> {
    inner(fuzzer, [])
  }

  fn inner(fuzzer, xs) -> Option<List<b>> {
    when fuzzer is {
      None -> Some(xs)
      Some(x) -> Some([x, ..xs])
    }
  }
  ```

  In this small program, we infer `list` first and run into `inner`.
  Yet, the arguments for `inner` are not annotated, so since we haven't
  inferred `inner` yet, we will create two unbound variables.

  And naturally, we will link the type of `[]` to being of the same type
  as `xs` -- which is still unbound at this point. The return type of
  `inner` is given by the annotation, so all-in-all, the unification
  will work without ever having to commit to a type of `[]`.

  It is only later, when `inner` is inferred, that we will generalise
  the unbound type of `xs` to a generic which the same as `b` in the
  annotation. At this point, `[]` is also typed with this same generic,
  which has a different id than `a` in `list` since it comes from
  another type definition.

  This is unfortunate and will cause issues down the line for the code
  generation. The problem doesn't occur when `inner`'s arguments are
  properly annotated or, when `inner` is actually inferred first.

  Hence, I saw two possible avenues for fixing this problem:

  1. Detect the presence of 'uncongruous generics' in definitions after
     they've all been inferred, and raise a user error asking for more
     annotations.

  2. Infer definitions in dependency order, with definitions used in
     other inferred first.

  This commit does (2) (although it may still be a good idea to do (1)
  eventually) since it offers a much better user experience. One way to
  do (2) is to construct a dependency graph between function calls, and
  ensure perform a topological sort.

  Building such graph is, however, quite tricky as it requires walking
  through the AST while maintaining scope etc. which is more-or-less
  already what the inferrence step is doing; so it feels like double
  work.

  Thus instead, this commit tries to do a deep-first inferrence and
  "pause" inferrence of definitions when encountering a call to fully
  infer the callee first. To achieve this properly, we must ensure that
  we do not infer the same definition again, so we "remember" already
  inferred definitions in the environment now.
  • Loading branch information
KtorZ authored and MicroProofs committed May 6, 2024
1 parent 3a1ab43 commit 19817b9
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 169 deletions.
15 changes: 13 additions & 2 deletions crates/aiken-lang/src/tipo/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use super::{
use crate::{
ast::{
Annotation, CallArg, DataType, Definition, Function, ModuleConstant, ModuleKind,
RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedPattern,
UnqualifiedImport, UntypedArg, UntypedDefinition, Use, Validator, PIPE_VARIABLE,
RecordConstructor, RecordConstructorArg, Span, TypeAlias, TypedDefinition, TypedFunction,
TypedPattern, UnqualifiedImport, UntypedArg, UntypedDefinition, UntypedFunction, Use,
Validator, PIPE_VARIABLE,
},
builtins::{function, generic_var, pair, tuple, unbound_var},
tipo::{fields::FieldMap, TypeAliasAnnotation},
Expand Down Expand Up @@ -54,6 +55,12 @@ pub struct Environment<'a> {
/// Values defined in the current module (or the prelude)
pub module_values: HashMap<String, ValueConstructor>,

/// Top-level function definitions from the module
pub module_functions: HashMap<String, &'a UntypedFunction>,

/// Top-level functions that have been inferred
pub inferred_functions: HashMap<String, TypedFunction>,

previous_id: u64,

/// Values defined in the current function (or the prelude)
Expand Down Expand Up @@ -707,9 +714,11 @@ impl<'a> Environment<'a> {
previous_id: id_gen.next(),
id_gen,
ungeneralised_functions: HashSet::new(),
inferred_functions: HashMap::new(),
module_types: prelude.types.clone(),
module_types_constructors: prelude.types_constructors.clone(),
module_values: HashMap::new(),
module_functions: HashMap::new(),
imported_modules: HashMap::new(),
unused_modules: HashMap::new(),
unqualified_imported_names: HashMap::new(),
Expand Down Expand Up @@ -1201,6 +1210,8 @@ impl<'a> Environment<'a> {
&fun.location,
)?;

self.module_functions.insert(fun.name.clone(), fun);

if !fun.public {
self.init_usage(fun.name.clone(), EntityKind::PrivateFunction, fun.location);
}
Expand Down
190 changes: 166 additions & 24 deletions crates/aiken-lang/src/tipo/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::{
environment::{assert_no_labeled_arguments, collapse_links, EntityKind, Environment},
environment::{
assert_no_labeled_arguments, collapse_links, generalise, EntityKind, Environment,
},
error::{Error, Warning},
hydrator::Hydrator,
pattern::PatternTyper,
Expand All @@ -9,11 +11,12 @@ use super::{
use crate::{
ast::{
self, Annotation, Arg, ArgName, AssignmentKind, AssignmentPattern, BinOp, Bls12_381Point,
ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, IfBranch,
ByteArrayFormatPreference, CallArg, ClauseGuard, Constant, Curve, Function, IfBranch,
LogicalOpChainKind, Pattern, RecordUpdateSpread, Span, TraceKind, TraceLevel, Tracing,
TypedArg, TypedCallArg, TypedClause, TypedClauseGuard, TypedIfBranch, TypedPattern,
TypedRecordUpdateArg, UnOp, UntypedArg, UntypedAssignmentKind, UntypedClause,
UntypedClauseGuard, UntypedIfBranch, UntypedPattern, UntypedRecordUpdateArg,
UntypedClauseGuard, UntypedFunction, UntypedIfBranch, UntypedPattern,
UntypedRecordUpdateArg,
},
builtins::{
bool, byte_array, function, g1_element, g2_element, int, list, pair, string, tuple, void,
Expand All @@ -26,12 +29,126 @@ use crate::{
use std::{cmp::Ordering, collections::HashMap, ops::Deref, rc::Rc};
use vec1::Vec1;

pub(crate) fn infer_function(
fun: &UntypedFunction,
module_name: &str,
hydrators: &mut HashMap<String, Hydrator>,
environment: &mut Environment<'_>,
lines: &LineNumbers,
tracing: Tracing,
) -> Result<Function<Rc<Type>, TypedExpr, TypedArg>, Error> {
if let Some(typed_fun) = environment.inferred_functions.get(&fun.name) {
return Ok(typed_fun.clone());
};

let Function {
doc,
location,
name,
public,
arguments,
body,
return_annotation,
end_position,
can_error,
return_type: _,
} = fun;

let preregistered_fn = environment
.get_variable(name)
.expect("Could not find preregistered type for function");

let field_map = preregistered_fn.field_map().cloned();

let preregistered_type = preregistered_fn.tipo.clone();

let (args_types, return_type) = preregistered_type
.function_types()
.unwrap_or_else(|| panic!("Preregistered type for fn {name} was not a fn"));

// Infer the type using the preregistered args + return types as a starting point
let (tipo, arguments, body, safe_to_generalise) = environment.in_new_scope(|environment| {
let args = arguments
.iter()
.zip(&args_types)
.map(|(arg_name, tipo)| arg_name.to_owned().set_type(tipo.clone()))
.collect();

let hydrator = hydrators
.remove(name)
.unwrap_or_else(|| panic!("Could not find hydrator for fn {name}"));

let mut expr_typer = ExprTyper::new(environment, hydrators, lines, tracing);

expr_typer.hydrator = hydrator;

let (args, body, return_type) =
expr_typer.infer_fn_with_known_types(args, body.to_owned(), Some(return_type))?;

let args_types = args.iter().map(|a| a.tipo.clone()).collect();

let tipo = function(args_types, return_type);

let safe_to_generalise = !expr_typer.ungeneralised_function_used;

Ok::<_, Error>((tipo, args, body, safe_to_generalise))
})?;

// Assert that the inferred type matches the type of any recursive call
environment.unify(preregistered_type, tipo.clone(), *location, false)?;

// Generalise the function if safe to do so
let tipo = if safe_to_generalise {
environment.ungeneralised_functions.remove(name);

let tipo = generalise(tipo, 0);

let module_fn = ValueConstructorVariant::ModuleFn {
name: name.clone(),
field_map,
module: module_name.to_owned(),
arity: arguments.len(),
location: *location,
builtin: None,
};

environment.insert_variable(name.clone(), module_fn, tipo.clone());

tipo
} else {
tipo
};

let inferred_fn = Function {
doc: doc.clone(),
location: *location,
name: name.clone(),
public: *public,
arguments,
return_annotation: return_annotation.clone(),
return_type: tipo
.return_type()
.expect("Could not find return type for fn"),
body,
can_error: *can_error,
end_position: *end_position,
};

environment
.inferred_functions
.insert(name.to_string(), inferred_fn.clone());

Ok(inferred_fn)
}

#[derive(Debug)]
pub(crate) struct ExprTyper<'a, 'b> {
pub(crate) lines: &'a LineNumbers,

pub(crate) environment: &'a mut Environment<'b>,

pub(crate) hydrators: &'a mut HashMap<String, Hydrator>,

// We tweak the tracing behavior during type-check. Traces are either kept or left out of the
// typed AST depending on this setting.
pub(crate) tracing: Tracing,
Expand All @@ -46,6 +163,22 @@ pub(crate) struct ExprTyper<'a, 'b> {
}

impl<'a, 'b> ExprTyper<'a, 'b> {
pub fn new(
environment: &'a mut Environment<'b>,
hydrators: &'a mut HashMap<String, Hydrator>,
lines: &'a LineNumbers,
tracing: Tracing,
) -> Self {
Self {
hydrator: Hydrator::new(),
environment,
hydrators,
tracing,
ungeneralised_function_used: false,
lines,
}
}

fn check_when_exhaustiveness(
&mut self,
typed_clauses: &[TypedClause],
Expand Down Expand Up @@ -2184,17 +2317,40 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
variables: self.environment.local_value_names(),
})?;

// Note whether we are using an ungeneralised function so that we can
// tell if it is safe to generalise this function after inference has
// completed.
if matches!(
&constructor.variant,
ValueConstructorVariant::ModuleFn { .. }
) {
if let ValueConstructorVariant::ModuleFn { name: fn_name, .. } =
&constructor.variant
{
// Note whether we are using an ungeneralised function so that we can
// tell if it is safe to generalise this function after inference has
// completed.
let is_ungeneralised = self.environment.ungeneralised_functions.contains(name);

self.ungeneralised_function_used =
self.ungeneralised_function_used || is_ungeneralised;

// In case we use another function, infer it first before going further.
// This ensures we have as much information possible about the function
// when we start inferring expressions using it (i.e. calls).
//
// In a way, this achieves a cheap topological processing of definitions
// where we infer used definitions first. And as a consequence, it solves
// issues where expressions would be wrongly assigned generic variables
// from other definitions.
if let Some(fun) = self.environment.module_functions.remove(fn_name) {
// NOTE: Recursive functions should not run into this multiple time.
// If we have no hydrator for this function, it means that we have already
// encountered it.
if self.hydrators.get(&fun.name).is_some() {
infer_function(
fun,
self.environment.current_module,
self.hydrators,
self.environment,
self.lines,
self.tracing,
)?;
}
}
}

// Register the value as seen for detection of unused values
Expand Down Expand Up @@ -2323,20 +2479,6 @@ impl<'a, 'b> ExprTyper<'a, 'b> {
self.environment.instantiate(t, ids, &self.hydrator)
}

pub fn new(
environment: &'a mut Environment<'b>,
lines: &'a LineNumbers,
tracing: Tracing,
) -> Self {
Self {
hydrator: Hydrator::new(),
environment,
tracing,
ungeneralised_function_used: false,
lines,
}
}

pub fn new_unbound_var(&mut self) -> Rc<Type> {
self.environment.new_unbound_var()
}
Expand Down

0 comments on commit 19817b9

Please sign in to comment.