From 1a4315b8b753e5b5f2607b036df187033d743d03 Mon Sep 17 00:00:00 2001 From: bobxli Date: Sun, 30 Apr 2023 17:06:02 +0800 Subject: [PATCH 1/3] feat: closure type parser --- src/ast/builder/llvmbuilder.rs | 2 ++ src/ast/ctx.rs | 1 + src/ast/fmt.rs | 20 ++++++++++++++-- src/ast/node/mod.rs | 1 + src/ast/node/types.rs | 43 ++++++++++++++++++++++++++++++++++ src/ast/pltype.rs | 13 ++++++++++ src/nomparser/types.rs | 22 +++++++++++++++++ 7 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/ast/builder/llvmbuilder.rs b/src/ast/builder/llvmbuilder.rs index f48d43349..2d978c192 100644 --- a/src/ast/builder/llvmbuilder.rs +++ b/src/ast/builder/llvmbuilder.rs @@ -648,6 +648,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ]; Some(self.context.struct_type(&fields, false).into()) } + PLType::Closure(_) => todo!(), // TODO } } /// # get_ret_type @@ -936,6 +937,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ); Some(tp.as_type()) } + PLType::Closure(_) => todo!(), // TODO } } diff --git a/src/ast/ctx.rs b/src/ast/ctx.rs index 0c2630d2c..d11f79c36 100644 --- a/src/ast/ctx.rs +++ b/src/ast/ctx.rs @@ -807,6 +807,7 @@ impl<'a, 'ctx> Ctx<'a> { PLType::Pointer(_) => unreachable!(), PLType::PlaceHolder(_) => CompletionItemKind::STRUCT, PLType::Union(_) => CompletionItemKind::ENUM, + PLType::Closure(_) => unreachable!(), }; if k.starts_with('|') { // skip method diff --git a/src/ast/fmt.rs b/src/ast/fmt.rs index f1ac87f76..6c6a8e1ac 100644 --- a/src/ast/fmt.rs +++ b/src/ast/fmt.rs @@ -24,8 +24,9 @@ use super::{ string_literal::StringNode, tuple::{TupleInitNode, TupleTypeNode}, types::{ - ArrayInitNode, ArrayTypeNameNode, GenericDefNode, GenericParamNode, PointerTypeNode, - StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode, TypedIdentifierNode, + ArrayInitNode, ArrayTypeNameNode, ClosureTypeNode, GenericDefNode, GenericParamNode, + PointerTypeNode, StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode, + TypedIdentifierNode, }, union::UnionDefNode, FmtTrait, NodeEnum, TypeNodeEnum, @@ -768,4 +769,19 @@ impl FmtBuilder { } self.r_paren(); } + pub fn parse_closure_type_node(&mut self, node: &ClosureTypeNode) { + self.l_paren(); + for (i, ty) in node.arg_types.iter().enumerate() { + if i > 0 { + self.comma(); + self.space(); + } + ty.format(self); + } + self.r_paren(); + self.space(); + self.token("->"); + self.space(); + node.ret_type.format(self); + } } diff --git a/src/ast/node/mod.rs b/src/ast/node/mod.rs index 5fed00349..74ed13dce 100644 --- a/src/ast/node/mod.rs +++ b/src/ast/node/mod.rs @@ -74,6 +74,7 @@ pub enum TypeNodeEnum { Pointer(PointerTypeNode), Func(FuncDefNode), Tuple(TupleTypeNode), + Closure(ClosureTypeNode), } #[enum_dispatch] pub trait TypeNode: RangeTrait + FmtTrait + PrintTrait { diff --git a/src/ast/node/types.rs b/src/ast/node/types.rs index 13e0f8a74..00d835e20 100644 --- a/src/ast/node/types.rs +++ b/src/ast/node/types.rs @@ -906,3 +906,46 @@ impl GenericParamNode { } } } + +#[node] +pub struct ClosureTypeNode { + pub arg_types: Vec>, + pub ret_type: Box, +} + +impl TypeNode for ClosureTypeNode { + fn get_type<'a, 'ctx, 'b>( + &self, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, 'ctx>, + ) -> TypeNodeResult { + todo!() + } + + fn emit_highlight(&self, ctx: &mut Ctx) { + todo!() + } + + fn eq_or_infer<'a, 'ctx, 'b>( + &self, + ctx: &'b mut Ctx<'a>, + pltype: Arc>, + builder: &'b BuilderEnum<'a, 'ctx>, + ) -> Result { + todo!() + } +} + +impl PrintTrait for ClosureTypeNode { + fn print(&self, tabs: usize, end: bool, mut line: Vec) { + deal_line(tabs, &mut line, end); + tab(tabs, line.clone(), end); + println!("ClosureTypeNode"); + let mut i = self.arg_types.len(); + for g in &self.arg_types { + i -= 1; + g.print(tabs + 1, i == 0, line.clone()); + } + self.ret_type.print(tabs + 1, true, line.clone()); + } +} diff --git a/src/ast/pltype.rs b/src/ast/pltype.rs index 8e8e81b99..1517701e1 100644 --- a/src/ast/pltype.rs +++ b/src/ast/pltype.rs @@ -59,7 +59,12 @@ pub enum PLType { PlaceHolder(PlaceHolderType), Trait(STType), Union(UnionType), + Closure(ClosureType), } + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ClosureType {} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct UnionType { pub name: String, @@ -276,6 +281,7 @@ impl PLType { PLType::Generic(_) => "generic".to_string(), PLType::Trait(_) => "trait".to_string(), PLType::Union(_) => "union".to_string(), + PLType::Closure(_) => "closure".to_string(), } } pub fn get_typenode(&self) -> Box { @@ -300,6 +306,7 @@ impl PLType { PLType::Trait(t) => new_typename_node(&t.name, t.range), PLType::Fn(_) => unreachable!(), PLType::Union(u) => new_typename_node(&u.name, u.range), + PLType::Closure(_) => todo!(), // TODO } } pub fn is(&self, pri_type: &PriType) -> bool { @@ -320,6 +327,7 @@ impl PLType { PLType::Pointer(_) => (), PLType::Generic(g) => f_local(g), PLType::PlaceHolder(_) => (), + PLType::Closure(_) => (), } } @@ -343,6 +351,7 @@ impl PLType { PLType::PlaceHolder(p) => p.name.clone(), PLType::Trait(t) => t.name.clone(), PLType::Union(u) => u.name.clone(), + PLType::Closure(_) => todo!(), // TODO } } pub fn get_llvm_name(&self) -> String { @@ -365,6 +374,7 @@ impl PLType { } PLType::PlaceHolder(p) => p.get_place_holder_name(), PLType::Union(u) => u.name.clone(), + PLType::Closure(_) => todo!(), // TODO } } @@ -386,6 +396,7 @@ impl PLType { PLType::Pointer(p) => p.borrow().get_full_elm_name(), PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name(), + PLType::Closure(_) => todo!(), // TODO } } pub fn get_full_elm_name_without_generic(&self) -> String { @@ -406,6 +417,7 @@ impl PLType { PLType::Pointer(p) => p.borrow().get_full_elm_name(), PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name_except_generic(), + PLType::Closure(_) => todo!(), //TODO } } pub fn get_ptr_depth(&self) -> usize { @@ -484,6 +496,7 @@ impl PLType { PLType::PlaceHolder(p) => Some(p.range), PLType::Trait(t) => Some(t.range), PLType::Union(u) => Some(u.range), + PLType::Closure(_) => None, } } diff --git a/src/nomparser/types.rs b/src/nomparser/types.rs index 1ea1f0b7d..b1e4b49e8 100644 --- a/src/nomparser/types.rs +++ b/src/nomparser/types.rs @@ -1,5 +1,6 @@ use crate::ast::node::interface::{MultiTraitNode, TraitBoundNode}; use crate::ast::node::tuple::TupleTypeNode; +use crate::ast::node::types::ClosureTypeNode; use crate::nomparser::Span; use crate::{ ast::node::types::{ArrayTypeNameNode, TypeNameNode}, @@ -259,3 +260,24 @@ fn tuple_type(input: Span) -> IResult> { }, )(input) } + +#[test_parser("(i32,i64,(i32,i64)) => i32")] +fn closure_type(input: Span) -> IResult> { + map_res( + tuple((tuple_type, tag_token_symbol_ex(TokenType::ARROW), type_name)), + |(params, _, ret)| { + let range = params.range().start.to(ret.range().end); + match params.as_ref() { + TypeNodeEnum::Tuple(t) => { + let node = Box::new(TypeNodeEnum::Closure(ClosureTypeNode { + arg_types: t.tps.to_owned(), + ret_type: ret, + range, + })); + res_box(node) + } + _ => unreachable!(), + } + }, + )(input) +} From 760f6e4d7ce1619b9a7cecb5d66891ff23ab6d3a Mon Sep 17 00:00:00 2001 From: bobxli Date: Mon, 1 May 2023 15:16:47 +0800 Subject: [PATCH 2/3] feat: done add fntype support --- src/ast/builder/llvmbuilder.rs | 48 +- src/ast/ctx.rs | 57 +- src/ast/diag.rs | 6 +- src/ast/fmt.rs | 2 +- src/ast/node/function.rs | 241 ++++-- src/ast/node/primary.rs | 4 + src/ast/node/ret.rs | 2 +- src/ast/node/types.rs | 24 +- src/ast/pltype.rs | 52 +- src/ast/test.rs | 1287 ++++++++++++++++---------------- src/nomparser/types.rs | 2 +- test/fmt/test_fmt.pi | 31 + test/main.pi | 2 + test/test/fntype.pi | 32 + 14 files changed, 1025 insertions(+), 765 deletions(-) create mode 100644 test/test/fntype.pi diff --git a/src/ast/builder/llvmbuilder.rs b/src/ast/builder/llvmbuilder.rs index 2d978c192..d41126e2d 100644 --- a/src/ast/builder/llvmbuilder.rs +++ b/src/ast/builder/llvmbuilder.rs @@ -31,7 +31,7 @@ use inkwell::{ }; use rustc_hash::FxHashMap; -use crate::ast::{diag::PLDiag, pass::run_immix_pass}; +use crate::ast::{diag::PLDiag, pass::run_immix_pass, pltype::ClosureType}; use super::{ super::{ @@ -587,6 +587,22 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { fn_type }) } + + fn get_closure_fn_type(&self, closure: &ClosureType, ctx: &mut Ctx<'a>) -> FunctionType<'ctx> { + let params = closure + .arg_types + .iter() + .map(|pltype| { + let tp = self.get_basic_type_op(&pltype.borrow(), ctx).unwrap(); + let tp: BasicMetadataTypeEnum = tp.into(); + tp + }) + .collect::>(); + let fn_type = self + .get_ret_type(&closure.ret_type.borrow(), ctx) + .fn_type(¶ms, false); + fn_type + } /// # get_basic_type_op /// get the basic type of the type /// used in code generation @@ -648,7 +664,19 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ]; Some(self.context.struct_type(&fields, false).into()) } - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(c) => { + // all closures are represented as a struct with a function pointer and an i8ptr(point to closure data) + let fields = vec![ + self.get_closure_fn_type(c, ctx) + .ptr_type(AddressSpace::default()) + .into(), + self.context + .i8_type() + .ptr_type(AddressSpace::default()) + .into(), + ]; + Some(self.context.struct_type(&fields, false).into()) + } } } /// # get_ret_type @@ -937,7 +965,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> { ); Some(tp.as_type()) } - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(_) => self.get_ditype(&PLType::Primitive(PriType::I64), ctx), // TODO } } @@ -1124,6 +1152,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { | AnyValueEnum::PointerValue(_) | AnyValueEnum::StructValue(_) | AnyValueEnum::VectorValue(_) => handle, + AnyValueEnum::FunctionValue(f) => { + return Ok(self.get_llvm_value_handle(&f.as_global_value().as_any_value_enum())); + } _ => return Err(ctx.add_diag(range.new_err(ErrorCode::EXPECT_VALUE))), }) } else { @@ -1253,8 +1284,15 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> { let value = self.get_llvm_value(value).unwrap(); let ptr = self.get_llvm_value(ptr).unwrap(); let ptr = ptr.into_pointer_value(); - self.builder - .build_store::(ptr, value.try_into().unwrap()); + let value = if value.is_function_value() { + value + .into_function_value() + .as_global_value() + .as_basic_value_enum() + } else { + value.try_into().unwrap() + }; + self.builder.build_store(ptr, value); } fn build_const_in_bounds_gep( &self, diff --git a/src/ast/ctx.rs b/src/ast/ctx.rs index d11f79c36..c067adf33 100644 --- a/src/ast/ctx.rs +++ b/src/ast/ctx.rs @@ -214,14 +214,29 @@ impl<'a, 'ctx> Ctx<'a> { } pub fn up_cast<'b>( &mut self, - trait_pltype: Arc>, - st_pltype: Arc>, - trait_range: Range, - st_range: Range, - st_value: usize, + target_pltype: Arc>, + ori_pltype: Arc>, + target_range: Range, + ori_range: Range, + ori_value: usize, builder: &'b BuilderEnum<'a, 'ctx>, ) -> Result { - if let PLType::Union(u) = &*trait_pltype.borrow() { + if let PLType::Closure(_) = &*target_pltype.borrow() { + if ori_value == usize::MAX { + return Err(ori_range + .new_err(ErrorCode::CANNOT_ASSIGN_INCOMPLETE_GENERICS) + .add_help("try add generic type explicitly to fix this error.") + .add_to_ctx(self)); + } + let closure_v = builder.alloc("tmp", &target_pltype.borrow(), self, None); + let closure_f = builder.build_struct_gep(closure_v, 0, "closure_f").unwrap(); + let ori_value = builder.try_load2var(ori_range, ori_value, self)?; + // TODO now, we only handle the case that the closure is a pure function. + // TODO the real closure case is leave to the future. + builder.build_store(closure_f, ori_value); + return Ok(closure_v); + } + if let PLType::Union(u) = &*target_pltype.borrow() { let union_members = self.run_in_type_mod(u, |ctx, u| { let mut union_members = vec![]; for tp in &u.sum_types { @@ -231,9 +246,9 @@ impl<'a, 'ctx> Ctx<'a> { Ok(union_members) })?; for (i, tp) in union_members.iter().enumerate() { - if *tp.borrow() == *st_pltype.borrow() { + if *tp.borrow() == *ori_pltype.borrow() { let union_handle = - builder.alloc("tmp_unionv", &trait_pltype.borrow(), self, None); + builder.alloc("tmp_unionv", &target_pltype.borrow(), self, None); let union_value = builder .build_struct_gep(union_handle, 1, "union_value") .unwrap(); @@ -242,11 +257,11 @@ impl<'a, 'ctx> Ctx<'a> { .unwrap(); let union_type = builder.int_value(&PriType::U64, i as u64, false); builder.build_store(union_type_field, union_type); - let mut ptr = st_value; - if !builder.is_ptr(st_value) { + let mut ptr = ori_value; + if !builder.is_ptr(ori_value) { // mv to heap - ptr = builder.alloc("tmp", &st_pltype.borrow(), self, None); - builder.build_store(ptr, st_value); + ptr = builder.alloc("tmp", &ori_pltype.borrow(), self, None); + builder.build_store(ptr, ori_value); } let st_value = builder.bitcast( self, @@ -260,20 +275,20 @@ impl<'a, 'ctx> Ctx<'a> { } } } - let (st_pltype, st_value) = self.auto_deref(st_pltype, st_value, builder); + let (st_pltype, st_value) = self.auto_deref(ori_pltype, ori_value, builder); if let (PLType::Trait(t), PLType::Struct(st)) = - (&*trait_pltype.borrow(), &*st_pltype.borrow()) + (&*target_pltype.borrow(), &*st_pltype.borrow()) { if !st.implements_trait(t, &self.plmod) { return Err(mismatch_err!( self, - st_range, - trait_range, - trait_pltype.borrow(), + ori_range, + target_range, + target_pltype.borrow(), st_pltype.borrow() )); } - let trait_handle = builder.alloc("tmp_traitv", &trait_pltype.borrow(), self, None); + let trait_handle = builder.alloc("tmp_traitv", &target_pltype.borrow(), self, None); for f in t.list_trait_fields().iter() { let mthd = st.find_method(self, &f.name).unwrap(); let fnhandle = builder.get_or_insert_fn_handle(&mthd, self); @@ -303,9 +318,9 @@ impl<'a, 'ctx> Ctx<'a> { #[allow(clippy::needless_return)] return Err(mismatch_err!( self, - st_range, - trait_range, - trait_pltype.borrow(), + ori_range, + target_range, + target_pltype.borrow(), st_pltype.borrow() )); } diff --git a/src/ast/diag.rs b/src/ast/diag.rs index f3fe6c3c4..32e032fcc 100644 --- a/src/ast/diag.rs +++ b/src/ast/diag.rs @@ -120,7 +120,8 @@ define_error!( INVALID_IS_EXPR = "invalid `is` expression", INVALID_CAST = "invalid cast", METHOD_NOT_FOUND = "method not found", - DERIVE_TRAIT_NOT_IMPL = "derive trait not impl" + DERIVE_TRAIT_NOT_IMPL = "derive trait not impl", + CANNOT_ASSIGN_INCOMPLETE_GENERICS = "cannot assign incomplete generic function to variable", ); macro_rules! define_warn { ($( @@ -397,6 +398,9 @@ impl PLDiag { file: String, txt: Option<(String, Vec)>, ) -> &mut Self { + if range == Default::default() { + return self; + } self.raw.labels.push(PLLabel { file, txt, range }); self } diff --git a/src/ast/fmt.rs b/src/ast/fmt.rs index 6c6a8e1ac..90a93c787 100644 --- a/src/ast/fmt.rs +++ b/src/ast/fmt.rs @@ -780,7 +780,7 @@ impl FmtBuilder { } self.r_paren(); self.space(); - self.token("->"); + self.token("=>"); self.space(); node.ret_type.format(self); } diff --git a/src/ast/node/function.rs b/src/ast/node/function.rs index 59a32a5ad..fe69c39a3 100644 --- a/src/ast/node/function.rs +++ b/src/ast/node/function.rs @@ -3,10 +3,11 @@ use super::node_result::NodeResultBuilder; use super::statement::StatementsNode; use super::*; use super::{types::TypedIdentifierNode, Node, TypeNode}; +use crate::ast::builder::ValueHandle; use crate::ast::diag::ErrorCode; use crate::ast::node::{deal_line, tab}; -use crate::ast::pltype::{get_type_deep, FNValue, FnType, PLType}; +use crate::ast::pltype::{get_type_deep, ClosureType, FNValue, FnType, PLType}; use crate::ast::tokens::TokenType; use indexmap::IndexMap; use internal_macro::node; @@ -35,6 +36,102 @@ impl PrintTrait for FuncCallNode { } } +impl FuncCallNode { + fn handle_closure_call<'a, 'ctx, 'b>( + &mut self, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, 'ctx>, + c: &ClosureType, + v: ValueHandle, + ) -> NodeResult { + // TODO we only handle the case that the closure is a pure function + // TODO the real closure case is leave to the future + let mut para_values = vec![]; + let mut value_pltypes = vec![]; + if self.paralist.len() != c.arg_types.len() { + return Err(self + .range + .new_err(ErrorCode::PARAMETER_LENGTH_NOT_MATCH) + .add_to_ctx(ctx)); + } + for (i, para) in self.paralist.iter_mut().enumerate() { + let pararange = para.range(); + let v = ctx + .emit_with_expectation( + para, + c.arg_types[i].clone(), + c.arg_types[i].borrow().get_range().unwrap_or_default(), + builder, + )? + .get_value(); + if v.is_none() { + return Ok(Default::default()); + } + let v = v.unwrap(); + let value_pltype = v.get_ty(); + let value_pltype = get_type_deep(value_pltype); + let load = ctx.try_load2var(pararange, v.get_value(), builder)?; + para_values.push(load); + value_pltypes.push((value_pltype, pararange)); + } + let re = builder.build_struct_gep(v, 0, "real_fn").unwrap(); + let re = builder.build_load(re, "real_fn"); + let ret = builder.build_call(re, ¶_values, &c.ret_type.borrow(), ctx); + handle_ret(ret, c.ret_type.clone()) + } + + fn build_params<'a, 'ctx, 'b>( + &mut self, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, 'ctx>, + para_values: &mut Vec, + value_pltypes: &mut Vec<(Arc>, Range)>, + ) -> Result<(), PLDiag> { + for para in self.paralist.iter_mut() { + let pararange = para.range(); + let v = para.emit(ctx, builder)?.get_value(); + if v.is_none() { + return Ok(()); + } + let v = v.unwrap(); + let value_pltype = v.get_ty(); + let value_pltype = get_type_deep(value_pltype); + let load = ctx.try_load2var(pararange, v.get_value(), builder)?; + para_values.push(load); + value_pltypes.push((value_pltype, pararange)); + } + Ok(()) + } + + fn build_hint(&mut self, ctx: &mut Ctx, fnvalue: &FNValue, skip: u32) { + for (i, para) in self.paralist.iter_mut().enumerate() { + let pararange = para.range(); + ctx.push_param_hint(pararange, fnvalue.param_names[i + skip as usize].clone()); + ctx.set_if_sig( + para.range(), + fnvalue.name.clone().split("::").last().unwrap().to_string() + + "(" + + fnvalue + .param_names + .iter() + .enumerate() + .map(|(i, s)| { + s.clone() + + ": " + + FmtBuilder::generate_node(&fnvalue.fntype.param_pltypes[i]) + .as_str() + }) + .collect::>() + .join(", ") + .as_str() + + ")", + &fnvalue.param_names, + i as u32 + skip, + ); + } + } +} + impl Node for FuncCallNode { fn emit<'a, 'ctx, 'b>( &mut self, @@ -55,6 +152,9 @@ impl Node for FuncCallNode { res.fntype = res.fntype.new_pltype(); res } + PLType::Closure(c) => { + return self.handle_closure_call(ctx, builder, c, v.get_value()); + } _ => return Err(ctx.add_diag(self.range.new_err(ErrorCode::FUNCTION_NOT_FOUND))), }; @@ -94,47 +194,14 @@ impl Node for FuncCallNode { } let fn_handle = v.get_value(); if fnvalue.fntype.param_pltypes.len() - skip as usize != self.paralist.len() { - return Err(ctx.add_diag(self.range.new_err(ErrorCode::PARAMETER_LENGTH_NOT_MATCH))); - } - for (i, para) in self.paralist.iter_mut().enumerate() { - let pararange = para.range(); - ctx.push_param_hint(pararange, fnvalue.param_names[i + skip as usize].clone()); - ctx.set_if_sig( - para.range(), - fnvalue.name.clone().split("::").last().unwrap().to_string() - + "(" - + fnvalue - .param_names - .iter() - .enumerate() - .map(|(i, s)| { - s.clone() - + ": " - + FmtBuilder::generate_node(&fnvalue.fntype.param_pltypes[i]) - .as_str() - }) - .collect::>() - .join(", ") - .as_str() - + ")", - &fnvalue.param_names, - i as u32 + skip, - ); + return Err(self + .range + .new_err(ErrorCode::PARAMETER_LENGTH_NOT_MATCH) + .add_to_ctx(ctx)); } + self.build_hint(ctx, &fnvalue, skip); let mut value_pltypes = vec![]; - for para in self.paralist.iter_mut() { - let pararange = para.range(); - let v = para.emit(ctx, builder)?.get_value(); - if v.is_none() { - return Ok(Default::default()); - } - let v = v.unwrap(); - let value_pltype = v.get_ty(); - let value_pltype = get_type_deep(value_pltype); - let load = ctx.try_load2var(pararange, v.get_value(), builder)?; - para_values.push(load); - value_pltypes.push((value_pltype, pararange)); - } + self.build_params(ctx, builder, &mut para_values, &mut value_pltypes)?; // value check and generic infer let res = ctx.protect_generic_context(&fnvalue.fntype.generic_map.clone(), |ctx| { ctx.run_in_type_mod_mut(&mut fnvalue, |ctx, fnvalue| { @@ -148,36 +215,14 @@ impl Node for FuncCallNode { ); } } - for (i, (value_pltype, pararange)) in value_pltypes.iter().enumerate() { - let eqres = fnvalue.fntype.param_pltypes[i + skip as usize].eq_or_infer( - ctx, - value_pltype.clone(), - builder, - )?; - if !eqres.eq { - return Err( - ctx.add_diag(pararange.new_err(ErrorCode::PARAMETER_TYPE_NOT_MATCH)) - ); - } - if eqres.need_up_cast { - let mut value = para_values[i + skip as usize]; - let ptr2v = - builder.alloc("tmp_up_cast_ptr", &value_pltype.borrow(), ctx, None); - builder.build_store(ptr2v, value); - let trait_pltype = fnvalue.fntype.param_pltypes[i + skip as usize] - .get_type(ctx, builder)?; - value = ctx.up_cast( - trait_pltype, - value_pltype.clone(), - fnvalue.fntype.param_pltypes[i + skip as usize].range(), - *pararange, - ptr2v, - builder, - )?; - value = ctx.try_load2var(*pararange, value, builder)?; - para_values[i + skip as usize] = value; - } - } + check_and_cast_params( + &value_pltypes, + &fnvalue.fntype.param_pltypes, + skip, + ctx, + builder, + &mut para_values, + )?; Ok(()) })?; if !fnvalue.fntype.generic_map.is_empty() { @@ -202,21 +247,59 @@ impl Node for FuncCallNode { })?; let ret = builder.build_call(function, ¶_values, &rettp.borrow(), ctx); ctx.save_if_comment_doc_hover(id_range, Some(fnvalue.doc.clone())); - match ret { - Some(v) => v - .new_output(match &*rettp.clone().borrow() { - PLType::Generic(g) => g.curpltype.as_ref().unwrap().clone(), - _ => rettp, - }) - .to_result(), - None => usize::MAX.new_output(rettp).to_result(), - } + handle_ret(ret, rettp) }); ctx.set_if_refs_tp(pltype, id_range); ctx.emit_comment_highlight(&self.comments[0]); res } } + +fn check_and_cast_params<'a, 'ctx, 'b>( + value_pltypes: &[(Arc>, Range)], + param_types: &[Box], + skip: u32, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, 'ctx>, + para_values: &mut [usize], +) -> Result<(), PLDiag> { + for (i, (value_pltype, pararange)) in value_pltypes.iter().enumerate() { + let eqres = + param_types[i + skip as usize].eq_or_infer(ctx, value_pltype.clone(), builder)?; + if !eqres.eq { + return Err(ctx.add_diag(pararange.new_err(ErrorCode::PARAMETER_TYPE_NOT_MATCH))); + } + if eqres.need_up_cast { + let mut value = para_values[i + skip as usize]; + let ptr2v = builder.alloc("tmp_up_cast_ptr", &value_pltype.borrow(), ctx, None); + builder.build_store(ptr2v, value); + let trait_pltype = param_types[i + skip as usize].get_type(ctx, builder)?; + value = ctx.up_cast( + trait_pltype, + value_pltype.clone(), + param_types[i + skip as usize].range(), + *pararange, + ptr2v, + builder, + )?; + value = ctx.try_load2var(*pararange, value, builder)?; + para_values[i + skip as usize] = value; + } + } + Ok(()) +} + +fn handle_ret(ret: Option, rettp: Arc>) -> Result { + match ret { + Some(v) => v + .new_output(match &*rettp.clone().borrow() { + PLType::Generic(g) => g.curpltype.as_ref().unwrap().clone(), + _ => rettp, + }) + .to_result(), + None => usize::MAX.new_output(rettp).to_result(), + } +} #[node] pub struct FuncDefNode { pub id: Box, diff --git a/src/ast/node/primary.rs b/src/ast/node/primary.rs index 5008e9828..f98b278dc 100644 --- a/src/ast/node/primary.rs +++ b/src/ast/node/primary.rs @@ -167,6 +167,10 @@ impl Node for VarNode { PLType::Fn(f) => { ctx.send_if_go_to_def(self.range, f.range, f.path.clone()); ctx.push_semantic_token(self.range, SemanticTokenType::FUNCTION, 0); + if !f.fntype.generic { + let handle = builder.get_or_insert_fn_handle(f, ctx); + return handle.new_output(tp.clone()).to_result(); + } return usize::MAX.new_output(tp.clone()).to_result(); } _ => return Err(ctx.add_diag(self.range.new_err(ErrorCode::VAR_NOT_FOUND))), diff --git a/src/ast/node/ret.rs b/src/ast/node/ret.rs index 2c1ffffe4..c7969be83 100644 --- a/src/ast/node/ret.rs +++ b/src/ast/node/ret.rs @@ -30,7 +30,7 @@ impl Node for RetNode { ) -> NodeResult { let ret_pltype = ctx.rettp.as_ref().unwrap().clone(); if let Some(ret_node) = &mut self.value { - // let (value, value_pltype, _) = ctx.emit_with_expectation(ret_node, Some(ret_pltype.clone()), ret_pltype.borrow().get_range().unwrap_or_default(), builder)?; + // TODO implicit cast && type infer let v = ret_node.emit(ctx, builder)?.get_value().unwrap(); ctx.emit_comment_highlight(&self.comments[0]); let value_pltype = v.get_ty(); diff --git a/src/ast/node/types.rs b/src/ast/node/types.rs index 00d835e20..350b7e963 100644 --- a/src/ast/node/types.rs +++ b/src/ast/node/types.rs @@ -15,6 +15,7 @@ use crate::ast::diag::ErrorCode; use crate::ast::plmod::MutVec; use crate::ast::pltype::get_type_deep; +use crate::ast::pltype::ClosureType; use crate::ast::pltype::{ARRType, Field, GenericType, PLType, STType}; use crate::ast::tokens::TokenType; use indexmap::IndexMap; @@ -919,11 +920,23 @@ impl TypeNode for ClosureTypeNode { ctx: &'b mut Ctx<'a>, builder: &'b BuilderEnum<'a, 'ctx>, ) -> TypeNodeResult { - todo!() + let mut arg_types = vec![]; + for g in self.arg_types.iter() { + arg_types.push(g.get_type(ctx, builder)?); + } + let ret_type = self.ret_type.get_type(ctx, builder)?; + Ok(Arc::new(RefCell::new(PLType::Closure(ClosureType { + arg_types, + ret_type, + range: self.range, + })))) } fn emit_highlight(&self, ctx: &mut Ctx) { - todo!() + for g in self.arg_types.iter() { + g.emit_highlight(ctx); + } + self.ret_type.emit_highlight(ctx); } fn eq_or_infer<'a, 'ctx, 'b>( @@ -932,7 +945,12 @@ impl TypeNode for ClosureTypeNode { pltype: Arc>, builder: &'b BuilderEnum<'a, 'ctx>, ) -> Result { - todo!() + let left = self.get_type(ctx, builder)?; + let eq = *left.borrow() == *pltype.borrow(); + Ok(crate::ast::ctx::EqRes { + eq, + need_up_cast: false, + }) } } diff --git a/src/ast/pltype.rs b/src/ast/pltype.rs index 1517701e1..7f027f587 100644 --- a/src/ast/pltype.rs +++ b/src/ast/pltype.rs @@ -1,5 +1,6 @@ use super::ctx::Ctx; use super::diag::ErrorCode; +use super::node::types::ClosureTypeNode; use super::plmod::Mod; use super::plmod::MutVec; use super::tokens::TokenType; @@ -62,8 +63,43 @@ pub enum PLType { Closure(ClosureType), } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClosureType {} +#[derive(Debug, Clone, Eq)] +pub struct ClosureType { + pub arg_types: Vec>>, + pub ret_type: Arc>, + pub range: Range, +} + +impl PartialEq for ClosureType { + fn eq(&self, other: &Self) -> bool { + self.arg_types == other.arg_types && self.ret_type == other.ret_type + } +} + +impl ClosureType { + pub fn to_type_node(&self) -> TypeNodeEnum { + TypeNodeEnum::Closure(ClosureTypeNode { + arg_types: self + .arg_types + .iter() + .map(|t| t.borrow().get_typenode()) + .collect(), + ret_type: self.ret_type.borrow().get_typenode(), + range: self.range, + }) + } + pub fn get_name(&self) -> String { + format!( + "({}) => {}", + self.arg_types + .iter() + .map(|t| t.borrow().get_name()) + .collect::>() + .join(", "), + self.ret_type.borrow().get_name() + ) + } +} #[derive(Debug, Clone, PartialEq, Eq)] pub struct UnionType { @@ -306,7 +342,7 @@ impl PLType { PLType::Trait(t) => new_typename_node(&t.name, t.range), PLType::Fn(_) => unreachable!(), PLType::Union(u) => new_typename_node(&u.name, u.range), - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(c) => Box::new(c.to_type_node()), } } pub fn is(&self, pri_type: &PriType) -> bool { @@ -351,7 +387,7 @@ impl PLType { PLType::PlaceHolder(p) => p.name.clone(), PLType::Trait(t) => t.name.clone(), PLType::Union(u) => u.name.clone(), - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(c) => c.get_name(), } } pub fn get_llvm_name(&self) -> String { @@ -374,7 +410,7 @@ impl PLType { } PLType::PlaceHolder(p) => p.get_place_holder_name(), PLType::Union(u) => u.name.clone(), - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(c) => c.get_name(), } } @@ -396,7 +432,7 @@ impl PLType { PLType::Pointer(p) => p.borrow().get_full_elm_name(), PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name(), - PLType::Closure(_) => todo!(), // TODO + PLType::Closure(c) => c.get_name(), } } pub fn get_full_elm_name_without_generic(&self) -> String { @@ -417,7 +453,7 @@ impl PLType { PLType::Pointer(p) => p.borrow().get_full_elm_name(), PLType::PlaceHolder(p) => p.name.clone(), PLType::Union(u) => u.get_full_name_except_generic(), - PLType::Closure(_) => todo!(), //TODO + PLType::Closure(c) => c.get_name(), } } pub fn get_ptr_depth(&self) -> usize { @@ -496,7 +532,7 @@ impl PLType { PLType::PlaceHolder(p) => Some(p.range), PLType::Trait(t) => Some(t.range), PLType::Union(u) => Some(u.range), - PLType::Closure(_) => None, + PLType::Closure(c) => Some(c.range), } } diff --git a/src/ast/test.rs b/src/ast/test.rs index 06d148ec8..a3b5ec2b8 100644 --- a/src/ast/test.rs +++ b/src/ast/test.rs @@ -1,681 +1,678 @@ -#[cfg(test)] -mod test { - use std::{ - cell::RefCell, - fs::remove_file, - sync::{Arc, Mutex}, - }; +#![cfg(test)] +use std::{ + cell::RefCell, + fs::remove_file, + sync::{Arc, Mutex}, +}; - use lsp_types::{ - CompletionItemKind, GotoDefinitionResponse, HoverContents, InlayHintLabel, MarkedString, - }; - use salsa::{accumulator::Accumulator, storage::HasJar}; +use lsp_types::{ + CompletionItemKind, GotoDefinitionResponse, HoverContents, InlayHintLabel, MarkedString, +}; +use salsa::{accumulator::Accumulator, storage::HasJar}; - use crate::{ - ast::{ - accumulators::{ - Completions, Diagnostics, DocSymbols, GotoDef, Hints, PLFormat, PLHover, - PLReferences, PLSignatureHelp, - }, - compiler::{compile_dry, ActionType}, - diag::DiagCode, - range::Pos, +use crate::{ + ast::{ + accumulators::{ + Completions, Diagnostics, DocSymbols, GotoDef, Hints, PLFormat, PLHover, PLReferences, + PLSignatureHelp, }, - db::Database, - lsp::mem_docs::{MemDocs, MemDocsInput}, - Db, + compiler::{compile_dry, ActionType}, + diag::DiagCode, + range::Pos, + }, + db::Database, + lsp::mem_docs::{MemDocs, MemDocsInput}, + Db, +}; + +fn test_lsp<'db, A>( + db: &'db dyn Db, + params: Option<(Pos, Option)>, + action: ActionType, + src: &str, +) -> Vec<::Data> +where + A: Accumulator, + dyn Db + 'db: HasJar<::Jar>, +{ + let docs = MemDocs::default(); + let pos = if let Some((pos, _)) = params { + Some(pos) + } else { + None }; - fn test_lsp<'db, A>( - db: &'db dyn Db, - params: Option<(Pos, Option)>, - action: ActionType, - src: &str, - ) -> Vec<::Data> - where - A: Accumulator, - dyn Db + 'db: HasJar<::Jar>, - { - let docs = MemDocs::default(); - let pos = if let Some((pos, _)) = params { - Some(pos) + // let db = Database::default(); + let input = MemDocsInput::new( + db, + Arc::new(Mutex::new(RefCell::new(docs))), + src.to_string(), + Default::default(), + action, + params, + pos, + ); + compile_dry(db, input).unwrap(); + compile_dry::accumulated::(db, input) +} +#[test] +fn test_diag() { + let comps = test_lsp::( + &Database::default(), + None, + ActionType::Diagnostic, + "test/lsp_diag/test_diag.pi", + ); + assert!(!comps.is_empty()); + let (file, diag) = &comps[0]; + assert!(file.contains("test_diag.pi")); + let mut diag = diag.clone(); + diag.sort_by(|a, b| { + if a.raw.range.start.line < b.raw.range.start.line + || (a.raw.range.start.line == b.raw.range.start.line + && a.raw.range.start.column < b.raw.range.start.column) + { + std::cmp::Ordering::Less + } else if a.raw.range.start.line == b.raw.range.start.line + && a.raw.range.start.column == b.raw.range.start.column + { + std::cmp::Ordering::Equal } else { - None - }; + std::cmp::Ordering::Greater + } + }); + assert_eq!(diag.len(), 11); + assert_eq!( + new_diag_range(10, 14, 10, 15), + diag[0].get_range().to_diag_range() + ); + assert_eq!( + diag[0].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::TYPE_MISMATCH) + ); + assert_eq!( + new_diag_range(19, 16, 19, 18), + diag[1].get_range().to_diag_range() + ); + assert_eq!( + diag[1].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::TYPE_MISMATCH) + ); + assert_eq!( + new_diag_range(21, 12, 21, 21), + diag[2].get_range().to_diag_range() + ); + assert_eq!( + diag[2].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_DIRECT_UNION_CAST) + ); + assert_eq!( + new_diag_range(22, 13, 22, 22), + diag[3].get_range().to_diag_range() + ); + assert_eq!( + diag[3].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_UNION_CAST) + ); + assert_eq!( + new_diag_range(23, 18, 23, 21), + diag[4].get_range().to_diag_range() + ); + assert_eq!( + diag[4].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::UNION_DOES_NOT_CONTAIN_TYPE) + ); + assert_eq!( + new_diag_range(24, 13, 24, 21), + diag[5].get_range().to_diag_range() + ); + assert_eq!( + diag[5].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_IS_EXPR) + ); + assert_eq!( + new_diag_range(28, 11, 28, 11), + diag[6].get_range().to_diag_range() + ); + assert_eq!( + diag[6].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) + ); + assert_eq!( + new_diag_range(30, 8, 30, 9), + diag[7].get_range().to_diag_range() + ); + assert_eq!( + diag[7].get_diag_code(), + DiagCode::Warn(crate::ast::diag::WarnCode::UNUSED_VARIABLE) + ); + assert_eq!( + new_diag_range(30, 13, 30, 13), + diag[8].get_range().to_diag_range() + ); + assert_eq!( + diag[8].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) + ); + assert_eq!( + new_diag_range(31, 15, 31, 15), + diag[9].get_range().to_diag_range() + ); + assert_eq!( + diag[9].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) + ); + assert_eq!( + new_diag_range(41, 5, 41, 7), + diag[10].get_range().to_diag_range() + ); + assert_eq!( + diag[10].get_diag_code(), + DiagCode::Err(crate::ast::diag::ErrorCode::DERIVE_TRAIT_NOT_IMPL) + ); +} +#[test] +fn test_memory_leak() { + let db = &mut Database::default(); + let docs = MemDocs::default(); + let params = Some(( + Pos { + line: 2, + column: 8, + offset: 0, + }, + None, + )); + let pos = if let Some((pos, _)) = params { + Some(pos) + } else { + None + }; - // let db = Database::default(); - let input = MemDocsInput::new( - db, - Arc::new(Mutex::new(RefCell::new(docs))), - src.to_string(), - Default::default(), - action, - params, - pos, - ); - compile_dry(db, input).unwrap(); - compile_dry::accumulated::(db, input) - } - #[test] - fn test_diag() { - let comps = test_lsp::( - &Database::default(), - None, - ActionType::Diagnostic, - "test/lsp_diag/test_diag.pi", - ); - assert!(!comps.is_empty()); - let (file, diag) = &comps[0]; - assert!(file.contains("test_diag.pi")); - let mut diag = diag.clone(); - diag.sort_by(|a, b| { - if a.raw.range.start.line < b.raw.range.start.line - || (a.raw.range.start.line == b.raw.range.start.line - && a.raw.range.start.column < b.raw.range.start.column) - { - std::cmp::Ordering::Less - } else if a.raw.range.start.line == b.raw.range.start.line - && a.raw.range.start.column == b.raw.range.start.column - { - std::cmp::Ordering::Equal - } else { - std::cmp::Ordering::Greater - } - }); - assert_eq!(diag.len(), 11); - assert_eq!( - new_diag_range(10, 14, 10, 15), - diag[0].get_range().to_diag_range() - ); - assert_eq!( - diag[0].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::TYPE_MISMATCH) - ); - assert_eq!( - new_diag_range(19, 16, 19, 18), - diag[1].get_range().to_diag_range() - ); - assert_eq!( - diag[1].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::TYPE_MISMATCH) - ); - assert_eq!( - new_diag_range(21, 12, 21, 21), - diag[2].get_range().to_diag_range() - ); - assert_eq!( - diag[2].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_DIRECT_UNION_CAST) - ); - assert_eq!( - new_diag_range(22, 13, 22, 22), - diag[3].get_range().to_diag_range() - ); - assert_eq!( - diag[3].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_UNION_CAST) - ); - assert_eq!( - new_diag_range(23, 18, 23, 21), - diag[4].get_range().to_diag_range() - ); - assert_eq!( - diag[4].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::UNION_DOES_NOT_CONTAIN_TYPE) - ); - assert_eq!( - new_diag_range(24, 13, 24, 21), - diag[5].get_range().to_diag_range() - ); - assert_eq!( - diag[5].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::INVALID_IS_EXPR) - ); - assert_eq!( - new_diag_range(28, 11, 28, 11), - diag[6].get_range().to_diag_range() - ); - assert_eq!( - diag[6].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) - ); - assert_eq!( - new_diag_range(30, 8, 30, 9), - diag[7].get_range().to_diag_range() - ); - assert_eq!( - diag[7].get_diag_code(), - DiagCode::Warn(crate::ast::diag::WarnCode::UNUSED_VARIABLE) - ); - assert_eq!( - new_diag_range(30, 13, 30, 13), - diag[8].get_range().to_diag_range() - ); - assert_eq!( - diag[8].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) - ); - assert_eq!( - new_diag_range(31, 15, 31, 15), - diag[9].get_range().to_diag_range() - ); - assert_eq!( - diag[9].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::MISSING_SEMI) - ); - assert_eq!( - new_diag_range(41, 5, 41, 7), - diag[10].get_range().to_diag_range() - ); - assert_eq!( - diag[10].get_diag_code(), - DiagCode::Err(crate::ast::diag::ErrorCode::DERIVE_TRAIT_NOT_IMPL) - ); + // let db = Database::default(); + let input = MemDocsInput::new( + db, + Arc::new(Mutex::new(RefCell::new(docs))), + "test/lsp/mod.pi".to_string(), + Default::default(), + ActionType::FindReferences, + params, + pos, + ); + let m = compile_dry(db, input).unwrap(); + let mod1 = m.plmod(db); + let path = crate::utils::canonicalize("test/lsp/mod.pi") + .unwrap() + .to_str() + .unwrap() + .to_string(); + input.docs(db).lock().unwrap().borrow_mut().change( + db, + new_diag_range(0, 0, 0, 0), + path, + "\n".repeat(2), + ); + input.set_action(db).to(ActionType::Diagnostic); + input.set_params(db).to(Some(( + Pos { + line: 1, + column: 1, + offset: 0, + }, + Some("\n".repeat(100)), + ))); + let m = compile_dry(db, input).unwrap(); + let mod2 = m.plmod(db); + assert_ne!(mod1, mod2); + let modstr1 = format!("{:?}", mod1); + let modstr2 = format!("{:?}", mod2); + assert_eq!(modstr1.len(), modstr2.len()); +} + +#[test] +fn test_struct_field_completion() { + let comps = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 9, + column: 10, + offset: 0, + }, + Some(".".to_string()), + )), + ActionType::Completion, + "test/lsp/test_completion.pi", + ); + assert!(!comps.is_empty()); + assert_eq!(comps[0].len(), 3); + let compstr = vec!["a", "b", "c"]; + for comp in comps[0].iter() { + assert!(compstr.contains(&comp.label.as_str())); } - #[test] - fn test_memory_leak() { - let db = &mut Database::default(); - let docs = MemDocs::default(); - let params = Some(( +} + +#[test] +fn test_completion() { + let comps = test_lsp::( + &Database::default(), + Some(( Pos { - line: 2, - column: 8, + line: 10, + column: 6, offset: 0, }, None, - )); - let pos = if let Some((pos, _)) = params { - Some(pos) - } else { - None - }; - - // let db = Database::default(); - let input = MemDocsInput::new( - db, - Arc::new(Mutex::new(RefCell::new(docs))), - "test/lsp/mod.pi".to_string(), - Default::default(), - ActionType::FindReferences, - params, - pos, - ); - let m = compile_dry(db, input).unwrap(); - let mod1 = m.plmod(db); - let path = crate::utils::canonicalize("test/lsp/mod.pi") - .unwrap() - .to_str() - .unwrap() - .to_string(); - input.docs(db).lock().unwrap().borrow_mut().change( - db, - new_diag_range(0, 0, 0, 0), - path, - "\n".repeat(2), - ); - input.set_action(db).to(ActionType::Diagnostic); - input.set_params(db).to(Some(( + )), + ActionType::Completion, + "test/lsp/test_completion.pi", + ); + assert!(!comps.is_empty()); + let lables = comps[0].iter().map(|c| c.label.clone()).collect::>(); + assert!(lables.contains(&"test1".to_string())); + assert!(lables.contains(&"name".to_string())); + assert!(lables.contains(&"if".to_string())); +} +#[test] +fn test_type_completion() { + let comps = test_lsp::( + &Database::default(), + Some(( Pos { - line: 1, - column: 1, + line: 5, + column: 7, offset: 0, }, - Some("\n".repeat(100)), - ))); - let m = compile_dry(db, input).unwrap(); - let mod2 = m.plmod(db); - assert_ne!(mod1, mod2); - let modstr1 = format!("{:?}", mod1); - let modstr2 = format!("{:?}", mod2); - assert_eq!(modstr1.len(), modstr2.len()); - } - - #[test] - fn test_struct_field_completion() { - let comps = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 9, - column: 10, - offset: 0, - }, - Some(".".to_string()), - )), - ActionType::Completion, - "test/lsp/test_completion.pi", - ); - assert!(!comps.is_empty()); - assert_eq!(comps[0].len(), 3); - let compstr = vec!["a", "b", "c"]; - for comp in comps[0].iter() { - assert!(compstr.contains(&comp.label.as_str())); - } - } - - #[test] - fn test_completion() { - let comps = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 10, - column: 6, - offset: 0, - }, - None, - )), - ActionType::Completion, - "test/lsp/test_completion.pi", - ); - assert!(!comps.is_empty()); - let lables = comps[0].iter().map(|c| c.label.clone()).collect::>(); - assert!(lables.contains(&"test1".to_string())); - assert!(lables.contains(&"name".to_string())); - assert!(lables.contains(&"if".to_string())); - } - #[test] - fn test_type_completion() { - let comps = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 5, - column: 7, - offset: 0, - }, - Some(":".to_string()), - )), - ActionType::Completion, - "test/lsp/test_completion.pi", - ); - assert!(!comps.is_empty()); - let lables = comps[0].iter().map(|c| c.label.clone()).collect::>(); - assert!(lables.contains(&"test".to_string())); // self refernece - assert!(lables.contains(&"i64".to_string())); - assert!(!lables.contains(&"name".to_string())); - assert!(lables.contains(&"test1".to_string())); - } + Some(":".to_string()), + )), + ActionType::Completion, + "test/lsp/test_completion.pi", + ); + assert!(!comps.is_empty()); + let lables = comps[0].iter().map(|c| c.label.clone()).collect::>(); + assert!(lables.contains(&"test".to_string())); // self refernece + assert!(lables.contains(&"i64".to_string())); + assert!(!lables.contains(&"name".to_string())); + assert!(lables.contains(&"test1".to_string())); +} - #[test] - fn test_st_field_completion() { - let comps = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 37, - column: 8, - offset: 0, - }, - Some(":".to_string()), - )), - ActionType::Completion, - "test/lsp/test_completion.pi", - ); - assert!(!comps.is_empty()); - let lables = comps[0].to_vec(); - assert!( - lables - .iter() - .any(|c| c.label == "mod" && c.kind == Some(CompletionItemKind::MODULE)), - "mod not found in completion" - ); - } +#[test] +fn test_st_field_completion() { + let comps = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 37, + column: 8, + offset: 0, + }, + Some(":".to_string()), + )), + ActionType::Completion, + "test/lsp/test_completion.pi", + ); + assert!(!comps.is_empty()); + let lables = comps[0].to_vec(); + assert!( + lables + .iter() + .any(|c| c.label == "mod" && c.kind == Some(CompletionItemKind::MODULE)), + "mod not found in completion" + ); +} - #[test] - fn test_st_field_exttp_completion() { - let comps = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 38, - column: 8, - offset: 0, - }, - Some(":".to_string()), - )), - ActionType::Completion, - "test/lsp/test_completion.pi", - ); - assert!(!comps.is_empty()); - let lables = comps[0].to_vec(); - assert!( - lables - .iter() - .any(|c| c.label == "pubname" && c.kind == Some(CompletionItemKind::STRUCT)), - "`pubname` not found in completion" - ); +#[test] +fn test_st_field_exttp_completion() { + let comps = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 38, + column: 8, + offset: 0, + }, + Some(":".to_string()), + )), + ActionType::Completion, + "test/lsp/test_completion.pi", + ); + assert!(!comps.is_empty()); + let lables = comps[0].to_vec(); + assert!( + lables + .iter() + .any(|c| c.label == "pubname" && c.kind == Some(CompletionItemKind::STRUCT)), + "`pubname` not found in completion" + ); +} +#[test] +fn test_hint() { + let hints = test_lsp::( + &Database::default(), + None, + ActionType::Hint, + "test/lsp/test_completion.pi", + ); + assert!(!hints.is_empty()); + assert!(!hints[0].is_empty()); + assert_eq!( + hints[0][0].label, + InlayHintLabel::String(": i64".to_string()) + ); +} +fn new_diag_range(sl: u32, sc: u32, el: u32, ec: u32) -> lsp_types::Range { + lsp_types::Range { + start: lsp_types::Position { + line: sl, + character: sc, + }, + end: lsp_types::Position { + line: el, + character: ec, + }, } - #[test] - fn test_hint() { - let hints = test_lsp::( - &Database::default(), +} +#[test] +fn test_goto_def() { + let def = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 39, + column: 14, + offset: 0, + }, None, - ActionType::Hint, - "test/lsp/test_completion.pi", - ); - assert!(!hints.is_empty()); - assert!(!hints[0].is_empty()); - assert_eq!( - hints[0][0].label, - InlayHintLabel::String(": i64".to_string()) - ); + )), + ActionType::GotoDef, + "test/lsp/test_completion.pi", + ); + assert!(!def.is_empty()); + if let GotoDefinitionResponse::Scalar(sc) = def[0].clone() { + assert!(sc.uri.to_string().contains("test/lsp/mod.pi")); + assert_eq!(sc.range, new_diag_range(1, 7, 1, 11)); + } else { + panic!("expect goto def to be scalar, found {:?}", def[0]) } - fn new_diag_range(sl: u32, sc: u32, el: u32, ec: u32) -> lsp_types::Range { - lsp_types::Range { - start: lsp_types::Position { - line: sl, - character: sc, - }, - end: lsp_types::Position { - line: el, - character: ec, +} +#[test] +fn test_hover_struct() { + let hovers = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 4, + column: 19, + offset: 0, }, - } - } - #[test] - fn test_goto_def() { - let def = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 39, - column: 14, - offset: 0, - }, - None, - )), - ActionType::GotoDef, - "test/lsp/test_completion.pi", - ); - assert!(!def.is_empty()); - if let GotoDefinitionResponse::Scalar(sc) = def[0].clone() { - assert!(sc.uri.to_string().contains("test/lsp/mod.pi")); - assert_eq!(sc.range, new_diag_range(1, 7, 1, 11)); - } else { - panic!("expect goto def to be scalar, found {:?}", def[0]) - } - } - #[test] - fn test_hover_struct() { - let hovers = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 4, - column: 19, - offset: 0, - }, - None, - )), - ActionType::Hover, - "test/lsp/mod2.pi", - ); - assert!(!hovers.is_empty()); - if let HoverContents::Array(v) = hovers[0].clone().contents { - if let MarkedString::String(st) = v[0].clone() { - assert_eq!(st.trim(), "# content".to_string()); - } else { - panic!("expect hover to be string, found {:?}", hovers[0]) - } + None, + )), + ActionType::Hover, + "test/lsp/mod2.pi", + ); + assert!(!hovers.is_empty()); + if let HoverContents::Array(v) = hovers[0].clone().contents { + if let MarkedString::String(st) = v[0].clone() { + assert_eq!(st.trim(), "# content".to_string()); } else { - panic!("expect goto def to be scalar, found {:?}", hovers[0]) + panic!("expect hover to be string, found {:?}", hovers[0]) } + } else { + panic!("expect goto def to be scalar, found {:?}", hovers[0]) } +} - #[test] - fn test_sig_help() { - let hovers = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 11, - column: 19, - offset: 0, - }, - None, - )), - ActionType::SignatureHelp, - "test/lsp/mod2.pi", - ); - assert!(!hovers.is_empty()); - assert!( - hovers[0] - .signatures - .iter() - .find(|s| { - s.label == "test_sig_help(i: i64, ii: bool)" && s.active_parameter == Some(0) - }) - .is_some(), - "expect to find test_sig_help(i: i64, ii: bool) with active parameter 0, found {:?}", - hovers[0] - ); - } - - #[test] - fn test_find_refs() { - let refs = test_lsp::( - &Database::default(), - Some(( - Pos { - line: 2, - column: 8, - offset: 0, - }, - None, - )), - ActionType::FindReferences, - "test/lsp/mod.pi", - ); - assert!(!refs.is_empty()); - let mut locs = vec![]; - for r in refs.iter() { - for l in r.iter() { - locs.push(l.clone()); - } - } - // assert_eq!(locs.len(), 3); - assert!(locs - .iter() - .find(|l| { - let ok = l.uri.to_string().contains("test/lsp/mod.pi"); - if ok { - assert!(l.range == new_diag_range(1, 7, 1, 11)) - } - ok - }) - .is_some()); - assert!(locs - .iter() - .find(|l| { - let ok = l.uri.to_string().contains("test/lsp/test_completion.pi"); - if ok { - assert!(l.range == new_diag_range(38, 11, 38, 15)) - } - ok - }) - .is_some()); - assert!(locs +#[test] +fn test_sig_help() { + let hovers = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 11, + column: 19, + offset: 0, + }, + None, + )), + ActionType::SignatureHelp, + "test/lsp/mod2.pi", + ); + assert!(!hovers.is_empty()); + assert!( + hovers[0] + .signatures .iter() - .find(|l| { - let ok = l.uri.to_string().contains("test/lsp/mod2.pi"); - if ok { - assert!(l.range == new_diag_range(3, 17, 3, 21)) - } - ok + .find(|s| { + s.label == "test_sig_help(i: i64, ii: bool)" && s.active_parameter == Some(0) }) - .is_some()); - } + .is_some(), + "expect to find test_sig_help(i: i64, ii: bool) with active parameter 0, found {:?}", + hovers[0] + ); +} - #[test] - fn test_doc_symbol() { - let symbols = test_lsp::( - &Database::default(), +#[test] +fn test_find_refs() { + let refs = test_lsp::( + &Database::default(), + Some(( + Pos { + line: 2, + column: 8, + offset: 0, + }, None, - ActionType::DocSymbol, - "test/lsp/test_completion.pi", - ); - assert!(!symbols.is_empty()); - assert!(!symbols[0].is_empty()); - let testst = symbols[0].iter().filter(|s| s.name == "test").last(); - assert!(testst.is_some(), "test struct not found"); - assert_eq!( - testst.unwrap().kind, - lsp_types::SymbolKind::STRUCT, - "expect test's type to be struct, found {:?}", - testst.unwrap().kind - ); - let expect = new_diag_range(0, 0, 5, 1); - assert_eq!( - testst.unwrap().range, - expect, - "expect test's range to be {:?}, found {:?}", - expect, - testst.unwrap().range - ); - let name1fn = symbols[0].iter().filter(|s| s.name == "name1").last(); - assert_eq!( - name1fn.unwrap().kind, - lsp_types::SymbolKind::FUNCTION, - "expect name1's type to be struct, found {:?}", - name1fn.unwrap().kind - ); - let expect = new_diag_range(26, 0, 29, 1); - assert_eq!( - name1fn.unwrap().range, - expect, - "expect name1's range to be {:?}, found {:?}", - expect, - name1fn.unwrap().range - ); + )), + ActionType::FindReferences, + "test/lsp/mod.pi", + ); + assert!(!refs.is_empty()); + let mut locs = vec![]; + for r in refs.iter() { + for l in r.iter() { + locs.push(l.clone()); + } } + // assert_eq!(locs.len(), 3); + assert!(locs + .iter() + .find(|l| { + let ok = l.uri.to_string().contains("test/lsp/mod.pi"); + if ok { + assert!(l.range == new_diag_range(1, 7, 1, 11)) + } + ok + }) + .is_some()); + assert!(locs + .iter() + .find(|l| { + let ok = l.uri.to_string().contains("test/lsp/test_completion.pi"); + if ok { + assert!(l.range == new_diag_range(38, 11, 38, 15)) + } + ok + }) + .is_some()); + assert!(locs + .iter() + .find(|l| { + let ok = l.uri.to_string().contains("test/lsp/mod2.pi"); + if ok { + assert!(l.range == new_diag_range(3, 17, 3, 21)) + } + ok + }) + .is_some()); +} - #[test] - #[cfg(feature = "jit")] - fn test_jit() { - use crate::ast::compiler::{compile, Options}; - use std::path::PathBuf; - let _l = crate::utils::plc_new::tests::TEST_COMPILE_MUTEX - .lock() - .unwrap(); - let out = "testjitout"; - let docs = MemDocs::default(); - let db = Database::default(); - let input = MemDocsInput::new( - &db, - Arc::new(Mutex::new(RefCell::new(docs))), - "test/main.pi".to_string(), - Default::default(), - ActionType::Compile, - None, - None, - ); - let outplb = "testjitout.bc"; - compile( - &db, - input, - out.to_string(), - Options { - optimization: crate::ast::compiler::HashOptimizationLevel::None, - genir: true, - printast: false, - flow: false, - fmt: false, - jit: true, - }, - ); - assert!( - crate::ast::compiler::run( - PathBuf::from(outplb).as_path(), - inkwell::OptimizationLevel::None, - ) == 0, - "jit compiled program exit with non-zero status" - ); - } - #[test] - fn test_compile() { - let out = "testout"; - let exe = PathBuf::from(out); - #[cfg(target_os = "windows")] - let exe = exe.with_extension("exe"); - _ = remove_file(&exe); - let _l = crate::utils::plc_new::tests::TEST_COMPILE_MUTEX - .lock() - .unwrap(); - use std::{path::PathBuf, process::Command}; +#[test] +fn test_doc_symbol() { + let symbols = test_lsp::( + &Database::default(), + None, + ActionType::DocSymbol, + "test/lsp/test_completion.pi", + ); + assert!(!symbols.is_empty()); + assert!(!symbols[0].is_empty()); + let testst = symbols[0].iter().filter(|s| s.name == "test").last(); + assert!(testst.is_some(), "test struct not found"); + assert_eq!( + testst.unwrap().kind, + lsp_types::SymbolKind::STRUCT, + "expect test's type to be struct, found {:?}", + testst.unwrap().kind + ); + let expect = new_diag_range(0, 0, 5, 1); + assert_eq!( + testst.unwrap().range, + expect, + "expect test's range to be {:?}, found {:?}", + expect, + testst.unwrap().range + ); + let name1fn = symbols[0].iter().filter(|s| s.name == "name1").last(); + assert_eq!( + name1fn.unwrap().kind, + lsp_types::SymbolKind::FUNCTION, + "expect name1's type to be struct, found {:?}", + name1fn.unwrap().kind + ); + let expect = new_diag_range(26, 0, 29, 1); + assert_eq!( + name1fn.unwrap().range, + expect, + "expect name1's range to be {:?}, found {:?}", + expect, + name1fn.unwrap().range + ); +} - use crate::ast::compiler::{compile, Options}; +#[test] +#[cfg(feature = "jit")] +fn test_jit() { + use crate::ast::compiler::{compile, Options}; + use std::path::PathBuf; + let _l = crate::utils::plc_new::tests::TEST_COMPILE_MUTEX + .lock() + .unwrap(); + let out = "testjitout"; + let docs = MemDocs::default(); + let db = Database::default(); + let input = MemDocsInput::new( + &db, + Arc::new(Mutex::new(RefCell::new(docs))), + "test/main.pi".to_string(), + Default::default(), + ActionType::Compile, + None, + None, + ); + let outplb = "testjitout.bc"; + compile( + &db, + input, + out.to_string(), + Options { + optimization: crate::ast::compiler::HashOptimizationLevel::None, + genir: true, + printast: false, + flow: false, + fmt: false, + jit: true, + }, + ); + assert!( + crate::ast::compiler::run( + PathBuf::from(outplb).as_path(), + inkwell::OptimizationLevel::None, + ) == 0, + "jit compiled program exit with non-zero status" + ); +} +#[test] +fn test_compile() { + let out = "testout"; + let exe = PathBuf::from(out); + #[cfg(target_os = "windows")] + let exe = exe.with_extension("exe"); + _ = remove_file(&exe); + let _l = crate::utils::plc_new::tests::TEST_COMPILE_MUTEX + .lock() + .unwrap(); + use std::{path::PathBuf, process::Command}; - let docs = MemDocs::default(); - let mut db = Database::default(); - let input = MemDocsInput::new( - &db, - Arc::new(Mutex::new(RefCell::new(docs))), - "test/main.pi".to_string(), - Default::default(), - ActionType::Compile, - None, - None, - ); - compile( - &db, - input, - out.to_string(), - Options { - optimization: crate::ast::compiler::HashOptimizationLevel::None, - genir: true, - printast: false, - flow: false, - fmt: false, - jit: false, - }, - ); - let exe = crate::utils::canonicalize(&exe) - .unwrap_or_else(|_| panic!("static compiled file not found {:?}", exe)); - let o = Command::new(exe.to_str().unwrap()) - .output() - .expect("failed to execute compiled program"); - assert!( - o.status.success(), - "static compiled program failed with status {:?} and output {:?} and error {:?}", - o.status, - String::from_utf8_lossy(&o.stdout), - String::from_utf8_lossy(&o.stderr) - ); - input.set_action(&mut db).to(ActionType::PrintAst); - compile( - &db, - input, - out.to_string(), - Options { - optimization: crate::ast::compiler::HashOptimizationLevel::Aggressive, - genir: false, - printast: true, - flow: false, - fmt: false, - jit: true, + use crate::ast::compiler::{compile, Options}; + + let docs = MemDocs::default(); + let mut db = Database::default(); + let input = MemDocsInput::new( + &db, + Arc::new(Mutex::new(RefCell::new(docs))), + "test/main.pi".to_string(), + Default::default(), + ActionType::Compile, + None, + None, + ); + compile( + &db, + input, + out.to_string(), + Options { + optimization: crate::ast::compiler::HashOptimizationLevel::None, + genir: true, + printast: false, + flow: false, + fmt: false, + jit: false, + }, + ); + let exe = crate::utils::canonicalize(&exe) + .unwrap_or_else(|_| panic!("static compiled file not found {:?}", exe)); + let o = Command::new(exe.to_str().unwrap()) + .output() + .expect("failed to execute compiled program"); + assert!( + o.status.success(), + "static compiled program failed with status {:?} and output {:?} and error {:?}", + o.status, + String::from_utf8_lossy(&o.stdout), + String::from_utf8_lossy(&o.stderr) + ); + input.set_action(&mut db).to(ActionType::PrintAst); + compile( + &db, + input, + out.to_string(), + Options { + optimization: crate::ast::compiler::HashOptimizationLevel::Aggressive, + genir: false, + printast: true, + flow: false, + fmt: false, + jit: true, + }, + ); + test_lsp::( + &Database::default(), + Some(( + Pos { + line: 10, + column: 6, + offset: 0, }, - ); - test_lsp::( - &Database::default(), - Some(( - Pos { - line: 10, - column: 6, - offset: 0, - }, - None, - )), - ActionType::LspFmt, - "test/main.pi", - ); - } + None, + )), + ActionType::LspFmt, + "test/main.pi", + ); +} - #[test] - fn test_fmt() { - let testfile = "test/fmt/test_fmt.pi"; - let text_edit = - test_lsp::(&Database::default(), None, ActionType::LspFmt, testfile); - debug_assert!(text_edit[0].is_empty()); - } +#[test] +fn test_fmt() { + let testfile = "test/fmt/test_fmt.pi"; + let text_edit = test_lsp::(&Database::default(), None, ActionType::LspFmt, testfile); + debug_assert!(text_edit[0].is_empty()); } diff --git a/src/nomparser/types.rs b/src/nomparser/types.rs index b1e4b49e8..1be96881d 100644 --- a/src/nomparser/types.rs +++ b/src/nomparser/types.rs @@ -28,7 +28,7 @@ pub fn type_name(input: Span) -> IResult> { delspace(map_res( pair( many0(tag_token_symbol(TokenType::TAKE_VAL)), - alt((basic_type, array_type, tuple_type)), + alt((basic_type, array_type, closure_type, tuple_type)), ), |(pts, n)| { let mut node = n; diff --git a/test/fmt/test_fmt.pi b/test/fmt/test_fmt.pi index 6008676a4..2e68c9d73 100644 --- a/test/fmt/test_fmt.pi +++ b/test/fmt/test_fmt.pi @@ -191,3 +191,34 @@ fn test_tuple() void { return; } +fn fn_as_param(f: (i64, i64) => i64) (i64, i64) => i64 { + panic::assert(f(1, 2) == 3); + return f; +} + +fn add(i: i64, j: i64) i64 { + return i + j; +} + +pub fn test_fntype() () => f { + let f = add; + let fc: (i64, i64) => i64 = f; + let re = fn_as_param(fc)(2, 2); + panic::assert(re == 4); + let ff: () => f = test_ret_f; + let f1 = test_ret_f().f(100, 2); + panic::assert(f1 == 102); + return ff; +} + +fn test_ret_f() f { + let d: (i64, i64) => i64 = add; + return f{ + f: d + }; +} + +pub struct f { + f: (i64, i64) => i64; +} + diff --git a/test/main.pi b/test/main.pi index 04b0856ae..aac1b5096 100644 --- a/test/main.pi +++ b/test/main.pi @@ -13,6 +13,7 @@ use project1::test::macros; use project1::test::union; use project1::test::multi_trait; use project1::test::tuple; +use project1::test::fntype; use pl_test::main; pub fn main() i64 { @@ -32,6 +33,7 @@ pub fn main() i64 { union::test_union(); multi_trait::test_multi_trait(); tuple::test_tuple(); + fntype::test_fntype()(); return 0; } diff --git a/test/test/fntype.pi b/test/test/fntype.pi new file mode 100644 index 000000000..9e9ffdfec --- /dev/null +++ b/test/test/fntype.pi @@ -0,0 +1,32 @@ +use core::panic; +fn fn_as_param(f: (i64, i64) => i64) (i64, i64) => i64 { + panic::assert(f(1, 2) == 3); + return f; +} + +fn add(i: i64, j: i64) i64 { + return i + j; +} + +pub fn test_fntype() () => f { + let f = add; + let fc: (i64, i64) => i64 = f; + let re = fn_as_param(fc)(2, 2); + panic::assert(re == 4); + let ff: () => f = test_ret_f; + let f1 = test_ret_f().f(100, 2); + panic::assert(f1 == 102); + return ff; +} + +fn test_ret_f() f { + let d: (i64, i64) => i64 = add; + return f{ + f: d + }; +} + +pub struct f { + f: (i64, i64) => i64; +} + From 659ba29652d966a9d1465f4d90ab5b66770d7067 Mon Sep 17 00:00:00 2001 From: bobxli Date: Thu, 4 May 2023 11:28:15 +0800 Subject: [PATCH 3/3] fix: missing closure assign tp check & fn auto cast --- src/ast/ctx.rs | 20 +++++++++++++++++++- src/ast/diag.rs | 1 + src/ast/node/statement.rs | 20 ++++++++++++++++++-- src/ast/pltype.rs | 16 ++++++++++++++++ 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/src/ast/ctx.rs b/src/ast/ctx.rs index c067adf33..873634100 100644 --- a/src/ast/ctx.rs +++ b/src/ast/ctx.rs @@ -23,6 +23,7 @@ use super::traits::CustomType; use crate::ast::builder::BuilderEnum; use crate::ast::builder::IRBuilder; +use crate::format_label; use crate::lsp::semantic_tokens::type_index; use crate::mismatch_err; @@ -221,7 +222,24 @@ impl<'a, 'ctx> Ctx<'a> { ori_value: usize, builder: &'b BuilderEnum<'a, 'ctx>, ) -> Result { - if let PLType::Closure(_) = &*target_pltype.borrow() { + if let (PLType::Closure(c), PLType::Fn(f)) = + (&*target_pltype.borrow(), &*ori_pltype.borrow()) + { + if f.to_closure_ty(self, builder) != *c { + return Err(ori_range + .new_err(ErrorCode::FUNCTION_TYPE_NOT_MATCH) + .add_label( + target_range, + self.get_file(), + format_label!("expected type `{}`", c.get_name()), + ) + .add_label( + ori_range, + self.get_file(), + format_label!("found type `{}`", f.to_closure_ty(self, builder).get_name()), + ) + .add_to_ctx(self)); + } if ori_value == usize::MAX { return Err(ori_range .new_err(ErrorCode::CANNOT_ASSIGN_INCOMPLETE_GENERICS) diff --git a/src/ast/diag.rs b/src/ast/diag.rs index 32e032fcc..30dc078ca 100644 --- a/src/ast/diag.rs +++ b/src/ast/diag.rs @@ -122,6 +122,7 @@ define_error!( METHOD_NOT_FOUND = "method not found", DERIVE_TRAIT_NOT_IMPL = "derive trait not impl", CANNOT_ASSIGN_INCOMPLETE_GENERICS = "cannot assign incomplete generic function to variable", + FUNCTION_TYPE_NOT_MATCH = "function type not match", ); macro_rules! define_warn { ($( diff --git a/src/ast/node/statement.rs b/src/ast/node/statement.rs index 8f5519f33..dd0d61d1e 100644 --- a/src/ast/node/statement.rs +++ b/src/ast/node/statement.rs @@ -64,12 +64,28 @@ impl Node for DefNode { return Err(ctx.add_diag(self.range.new_err(ErrorCode::UNDEFINED_TYPE))); } let re = re.unwrap(); - let tp = re.get_ty(); + let mut tp = re.get_ty(); + let v = if let PLType::Fn(f) = &*tp.clone().borrow() { + let oritp = tp; + let c = Arc::new(RefCell::new(PLType::Closure(f.to_closure_ty(ctx, builder)))); + tp = c.clone(); + ctx.up_cast( + c, + oritp, + Default::default(), + Default::default(), + re.get_value(), + builder, + ) + .unwrap() + } else { + re.get_value() + }; if pltype.is_none() { ctx.push_type_hints(self.var.range, tp.clone()); pltype = Some(tp); } - expv = Some(re.get_value()); + expv = Some(v); } let pltype = pltype.unwrap(); let ptr2value = builder.alloc( diff --git a/src/ast/pltype.rs b/src/ast/pltype.rs index 7f027f587..924ff12f9 100644 --- a/src/ast/pltype.rs +++ b/src/ast/pltype.rs @@ -600,6 +600,22 @@ impl TryFrom for FNValue { } } impl FNValue { + pub fn to_closure_ty<'a, 'ctx, 'b>( + &self, + ctx: &'b mut Ctx<'a>, + builder: &'b BuilderEnum<'a, 'ctx>, + ) -> ClosureType { + return ClosureType { + range: Default::default(), + ret_type: self.fntype.ret_pltype.get_type(ctx, builder).unwrap(), + arg_types: self + .fntype + .param_pltypes + .iter() + .map(|x| x.get_type(ctx, builder).unwrap()) + .collect(), + }; + } pub fn is_modified_by(&self, modifier: TokenType) -> bool { if let Some((t, _)) = self.fntype.modifier { t == modifier