Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/ast/builder/llvmbuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use inkwell::{
};
use rustc_hash::FxHashMap;

use crate::ast::{diag::PLDiag, pass::run_immix_pass};
use crate::ast::{diag::PLDiag, pass::run_immix_pass, pltype::ClosureType};

use super::{
super::{
Expand Down Expand Up @@ -587,6 +587,22 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
fn_type
})
}

fn get_closure_fn_type(&self, closure: &ClosureType, ctx: &mut Ctx<'a>) -> FunctionType<'ctx> {
let params = closure
.arg_types
.iter()
.map(|pltype| {
let tp = self.get_basic_type_op(&pltype.borrow(), ctx).unwrap();
let tp: BasicMetadataTypeEnum = tp.into();
tp
})
.collect::<Vec<_>>();
let fn_type = self
.get_ret_type(&closure.ret_type.borrow(), ctx)
.fn_type(&params, false);
fn_type
}
/// # get_basic_type_op
/// get the basic type of the type
/// used in code generation
Expand Down Expand Up @@ -648,6 +664,19 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
];
Some(self.context.struct_type(&fields, false).into())
}
PLType::Closure(c) => {
// all closures are represented as a struct with a function pointer and an i8ptr(point to closure data)
let fields = vec![
self.get_closure_fn_type(c, ctx)
.ptr_type(AddressSpace::default())
.into(),
self.context
.i8_type()
.ptr_type(AddressSpace::default())
.into(),
];
Some(self.context.struct_type(&fields, false).into())
}
}
}
/// # get_ret_type
Expand Down Expand Up @@ -936,6 +965,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
);
Some(tp.as_type())
}
PLType::Closure(_) => self.get_ditype(&PLType::Primitive(PriType::I64), ctx), // TODO
}
}

Expand Down Expand Up @@ -1122,6 +1152,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> {
| AnyValueEnum::PointerValue(_)
| AnyValueEnum::StructValue(_)
| AnyValueEnum::VectorValue(_) => handle,
AnyValueEnum::FunctionValue(f) => {
return Ok(self.get_llvm_value_handle(&f.as_global_value().as_any_value_enum()));
}
_ => return Err(ctx.add_diag(range.new_err(ErrorCode::EXPECT_VALUE))),
})
} else {
Expand Down Expand Up @@ -1251,8 +1284,15 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> {
let value = self.get_llvm_value(value).unwrap();
let ptr = self.get_llvm_value(ptr).unwrap();
let ptr = ptr.into_pointer_value();
self.builder
.build_store::<BasicValueEnum>(ptr, value.try_into().unwrap());
let value = if value.is_function_value() {
value
.into_function_value()
.as_global_value()
.as_basic_value_enum()
} else {
value.try_into().unwrap()
};
self.builder.build_store(ptr, value);
}
fn build_const_in_bounds_gep(
&self,
Expand Down
76 changes: 55 additions & 21 deletions src/ast/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use super::traits::CustomType;

use crate::ast::builder::BuilderEnum;
use crate::ast::builder::IRBuilder;
use crate::format_label;
use crate::lsp::semantic_tokens::type_index;

use crate::mismatch_err;
Expand Down Expand Up @@ -214,14 +215,46 @@ impl<'a, 'ctx> Ctx<'a> {
}
pub fn up_cast<'b>(
&mut self,
trait_pltype: Arc<RefCell<PLType>>,
st_pltype: Arc<RefCell<PLType>>,
trait_range: Range,
st_range: Range,
st_value: usize,
target_pltype: Arc<RefCell<PLType>>,
ori_pltype: Arc<RefCell<PLType>>,
target_range: Range,
ori_range: Range,
ori_value: usize,
builder: &'b BuilderEnum<'a, 'ctx>,
) -> Result<usize, PLDiag> {
if let PLType::Union(u) = &*trait_pltype.borrow() {
if let (PLType::Closure(c), PLType::Fn(f)) =
(&*target_pltype.borrow(), &*ori_pltype.borrow())
{
if f.to_closure_ty(self, builder) != *c {
return Err(ori_range
.new_err(ErrorCode::FUNCTION_TYPE_NOT_MATCH)
.add_label(
target_range,
self.get_file(),
format_label!("expected type `{}`", c.get_name()),
)
.add_label(
ori_range,
self.get_file(),
format_label!("found type `{}`", f.to_closure_ty(self, builder).get_name()),
)
.add_to_ctx(self));
}
if ori_value == usize::MAX {
return Err(ori_range
.new_err(ErrorCode::CANNOT_ASSIGN_INCOMPLETE_GENERICS)
.add_help("try add generic type explicitly to fix this error.")
.add_to_ctx(self));
}
let closure_v = builder.alloc("tmp", &target_pltype.borrow(), self, None);
let closure_f = builder.build_struct_gep(closure_v, 0, "closure_f").unwrap();
let ori_value = builder.try_load2var(ori_range, ori_value, self)?;
// TODO now, we only handle the case that the closure is a pure function.
// TODO the real closure case is leave to the future.
builder.build_store(closure_f, ori_value);
return Ok(closure_v);
}
if let PLType::Union(u) = &*target_pltype.borrow() {
let union_members = self.run_in_type_mod(u, |ctx, u| {
let mut union_members = vec![];
for tp in &u.sum_types {
Expand All @@ -231,9 +264,9 @@ impl<'a, 'ctx> Ctx<'a> {
Ok(union_members)
})?;
for (i, tp) in union_members.iter().enumerate() {
if *tp.borrow() == *st_pltype.borrow() {
if *tp.borrow() == *ori_pltype.borrow() {
let union_handle =
builder.alloc("tmp_unionv", &trait_pltype.borrow(), self, None);
builder.alloc("tmp_unionv", &target_pltype.borrow(), self, None);
let union_value = builder
.build_struct_gep(union_handle, 1, "union_value")
.unwrap();
Expand All @@ -242,11 +275,11 @@ impl<'a, 'ctx> Ctx<'a> {
.unwrap();
let union_type = builder.int_value(&PriType::U64, i as u64, false);
builder.build_store(union_type_field, union_type);
let mut ptr = st_value;
if !builder.is_ptr(st_value) {
let mut ptr = ori_value;
if !builder.is_ptr(ori_value) {
// mv to heap
ptr = builder.alloc("tmp", &st_pltype.borrow(), self, None);
builder.build_store(ptr, st_value);
ptr = builder.alloc("tmp", &ori_pltype.borrow(), self, None);
builder.build_store(ptr, ori_value);
}
let st_value = builder.bitcast(
self,
Expand All @@ -260,20 +293,20 @@ impl<'a, 'ctx> Ctx<'a> {
}
}
}
let (st_pltype, st_value) = self.auto_deref(st_pltype, st_value, builder);
let (st_pltype, st_value) = self.auto_deref(ori_pltype, ori_value, builder);
if let (PLType::Trait(t), PLType::Struct(st)) =
(&*trait_pltype.borrow(), &*st_pltype.borrow())
(&*target_pltype.borrow(), &*st_pltype.borrow())
{
if !st.implements_trait(t, &self.plmod) {
return Err(mismatch_err!(
self,
st_range,
trait_range,
trait_pltype.borrow(),
ori_range,
target_range,
target_pltype.borrow(),
st_pltype.borrow()
));
}
let trait_handle = builder.alloc("tmp_traitv", &trait_pltype.borrow(), self, None);
let trait_handle = builder.alloc("tmp_traitv", &target_pltype.borrow(), self, None);
for f in t.list_trait_fields().iter() {
let mthd = st.find_method(self, &f.name).unwrap();
let fnhandle = builder.get_or_insert_fn_handle(&mthd, self);
Expand Down Expand Up @@ -303,9 +336,9 @@ impl<'a, 'ctx> Ctx<'a> {
#[allow(clippy::needless_return)]
return Err(mismatch_err!(
self,
st_range,
trait_range,
trait_pltype.borrow(),
ori_range,
target_range,
target_pltype.borrow(),
st_pltype.borrow()
));
}
Expand Down Expand Up @@ -807,6 +840,7 @@ impl<'a, 'ctx> Ctx<'a> {
PLType::Pointer(_) => unreachable!(),
PLType::PlaceHolder(_) => CompletionItemKind::STRUCT,
PLType::Union(_) => CompletionItemKind::ENUM,
PLType::Closure(_) => unreachable!(),
};
if k.starts_with('|') {
// skip method
Expand Down
7 changes: 6 additions & 1 deletion src/ast/diag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ define_error!(
INVALID_IS_EXPR = "invalid `is` expression",
INVALID_CAST = "invalid cast",
METHOD_NOT_FOUND = "method not found",
DERIVE_TRAIT_NOT_IMPL = "derive trait not impl"
DERIVE_TRAIT_NOT_IMPL = "derive trait not impl",
CANNOT_ASSIGN_INCOMPLETE_GENERICS = "cannot assign incomplete generic function to variable",
FUNCTION_TYPE_NOT_MATCH = "function type not match",
);
macro_rules! define_warn {
($(
Expand Down Expand Up @@ -397,6 +399,9 @@ impl PLDiag {
file: String,
txt: Option<(String, Vec<String>)>,
) -> &mut Self {
if range == Default::default() {
return self;
}
self.raw.labels.push(PLLabel { file, txt, range });
self
}
Expand Down
20 changes: 18 additions & 2 deletions src/ast/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use super::{
string_literal::StringNode,
tuple::{TupleInitNode, TupleTypeNode},
types::{
ArrayInitNode, ArrayTypeNameNode, GenericDefNode, GenericParamNode, PointerTypeNode,
StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode, TypedIdentifierNode,
ArrayInitNode, ArrayTypeNameNode, ClosureTypeNode, GenericDefNode, GenericParamNode,
PointerTypeNode, StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode,
TypedIdentifierNode,
},
union::UnionDefNode,
FmtTrait, NodeEnum, TypeNodeEnum,
Expand Down Expand Up @@ -768,4 +769,19 @@ impl FmtBuilder {
}
self.r_paren();
}
pub fn parse_closure_type_node(&mut self, node: &ClosureTypeNode) {
self.l_paren();
for (i, ty) in node.arg_types.iter().enumerate() {
if i > 0 {
self.comma();
self.space();
}
ty.format(self);
}
self.r_paren();
self.space();
self.token("=>");
self.space();
node.ret_type.format(self);
}
}
Loading