Skip to content

Commit

Permalink
feat: add try_from trait impl for inner AstStatement structs
Browse files Browse the repository at this point in the history
  • Loading branch information
mhasel committed May 2, 2024
1 parent 7807d9d commit cf35575
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 74 deletions.
1 change: 1 addition & 0 deletions compiler/plc_ast/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ plc_util = { path = "../plc_util" }
plc_source = {path = "../plc_source"}
chrono = { version = "0.4", default-features = false }
serde = { version = "1.0", features = ["derive"] }
derive_more = { version = "0.99.0", features = ["try_into"] }
54 changes: 20 additions & 34 deletions compiler/plc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::{
ops::Range,
};

use derive_more::TryInto;
use serde::{Deserialize, Serialize};

use crate::{
Expand Down Expand Up @@ -593,7 +594,8 @@ pub struct AstNode {
pub location: SourceLocation,
}

#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, TryInto)]
#[try_into(ref)]
pub enum AstStatement {
EmptyStatement(EmptyStatement),
// a placeholder that indicates a default value of a datatype
Expand Down Expand Up @@ -623,13 +625,25 @@ pub enum AstStatement {
ControlStatement(AstControlStatement),

CaseCondition(Box<AstNode>),
#[try_into(ignore)]
ExitStatement(()),
#[try_into(ignore)]
ContinueStatement(()),
ReturnStatement(ReturnStatement),
JumpStatement(JumpStatement),
LabelStatement(LabelStatement),
}

#[macro_export]
/// A `try_from` convenience wrapper for `AstNode`, passed as the `ex:expr` argument.
/// Will try to return a reference to the variants inner type, specified via the `t:ty` parameter.
/// Converts the `try_from`-`Result` into an `Option`
macro_rules! try_from {
($ex:expr, $t:ty) => {
<&$t>::try_from($ex.get_stmt()).ok()
};
}

impl Debug for AstNode {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match &self.stmt {
Expand Down Expand Up @@ -754,11 +768,7 @@ impl Debug for AstNode {
impl AstNode {
///Returns the statement in a singleton list, or the contained statements if the statement is already a list
pub fn get_as_list(&self) -> Vec<&AstNode> {
if let AstStatement::ExpressionList(expressions) = &self.stmt {
expressions.iter().collect::<Vec<&AstNode>>()
} else {
vec![self]
}
try_from!(self, Vec<AstNode>).map(|it| it.iter().collect()).unwrap_or(vec![self])
}

pub fn get_location(&self) -> SourceLocation {
Expand Down Expand Up @@ -819,17 +829,7 @@ impl AstNode {

/// Returns true if the current statement is a flat reference (e.g. `a`)
pub fn is_flat_reference(&self) -> bool {
matches!(self.stmt, AstStatement::Identifier(..)) || {
if let AstStatement::ReferenceExpr(
ReferenceExpr { access: ReferenceAccess::Member(reference), base: None },
..,
) = &self.stmt
{
matches!(reference.as_ref().stmt, AstStatement::Identifier(..))
} else {
false
}
}
self.get_flat_reference_name().is_some()
}

/// Returns the reference-name if this is a flat reference like `a`, or None if this is no flat reference
Expand All @@ -845,10 +845,7 @@ impl AstNode {
}

pub fn get_label_name(&self) -> Option<&str> {
match &self.stmt {
AstStatement::LabelStatement(LabelStatement { name, .. }) => Some(name.as_str()),
_ => None,
}
try_from!(self, LabelStatement).map(|it| it.name.as_str())
}

pub fn is_empty_statement(&self) -> bool {
Expand Down Expand Up @@ -929,15 +926,7 @@ impl AstNode {

/// Returns true if the given token is an integer or float and zero.
pub fn is_zero(&self) -> bool {
match &self.stmt {
AstStatement::Literal(kind, ..) => match kind {
AstLiteral::Integer(0) => true,
AstLiteral::Real(val) => val == "0" || val == "0.0",
_ => false,
},

_ => false,
}
try_from!(self, AstLiteral).is_some_and(|it| it.is_zero())
}

pub fn is_binary_expression(&self) -> bool {
Expand All @@ -957,10 +946,7 @@ impl AstNode {
}

pub fn get_literal_integer_value(&self) -> Option<i128> {
match &self.stmt {
AstStatement::Literal(AstLiteral::Integer(value), ..) => Some(*value),
_ => None,
}
try_from!(self, AstLiteral).map(|it| it.get_literal_integer_value()).unwrap_or_default()
}

pub fn is_identifier(&self) -> bool {
Expand Down
48 changes: 47 additions & 1 deletion compiler/plc_ast/src/literals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt::{Debug, Formatter};
use chrono::NaiveDate;

use crate::ast::AstNode;
use derive_more::TryInto;

macro_rules! impl_getters {
($type:ty, [$($name:ident),+], [$($out:ty),+]) => {
Expand All @@ -14,7 +15,8 @@ macro_rules! impl_getters {
}
}

#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, TryInto)]
#[try_into(ref)]
pub enum AstLiteral {
/// a null literal used to initialize pointers
Null,
Expand All @@ -38,6 +40,37 @@ pub enum AstLiteral {
Array(Array),
}

// macro_rules! impl_try_from {
// (for $($id:ident),+) => {
// $(impl<'ast> TryFrom<&'ast AstNode> for &'ast $id {
// type Error = ();

// fn try_from(value: &'ast AstNode) -> Result<Self, Self::Error> {
// let crate::ast::AstStatement::Literal(AstLiteral::$id(inner)) = value.get_stmt() else {
// return Err(())
// };
// Ok(inner)
// }
// })*
// };
// (for $($id:ident, $p:path),+) => {
// $(impl<'ast> TryFrom<&'ast AstNode> for &'ast $id {
// type Error = ();

// fn try_from(value: &'ast AstNode) -> Result<Self, Self::Error> {
// let crate::ast::AstStatement::Literal($p(inner)) = value.get_stmt() else {
// return Err(())
// };
// Ok(inner)
// }
// })*
// };
// }

// impl_try_from!(for Date, DateAndTime, TimeOfDay, Time, Array);
// // XXX: String::try_from(..) is ambiguous between `AstLiteral::Real` and `AstStatement::Identifier`
// impl_try_from!(for i128, AstLiteral::Integer, String, AstLiteral::Real, bool, AstLiteral::Bool, StringValue, AstLiteral::String);

#[derive(Debug, Clone, PartialEq)]
pub struct Date {
year: i32,
Expand Down Expand Up @@ -271,6 +304,19 @@ impl AstLiteral {
| AstLiteral::DateAndTime { .. }
)
}

pub fn is_zero(&self) -> bool {
match self {
AstLiteral::Integer(0) => true,
AstLiteral::Real(val) => val == "0" || val == "0.0",
_ => false,
}
}

pub fn get_literal_integer_value(&self) -> Option<i128> {
let Self::Integer(val) = self else { return None };
Some(*val)
}
}

impl Debug for AstLiteral {
Expand Down
42 changes: 19 additions & 23 deletions src/codegen/generators/expression_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ use inkwell::{
};
use plc_ast::{
ast::{
flatten_expression_list, AstFactory, AstNode, AstStatement, DirectAccessType, Operator,
flatten_expression_list, Assignment, AstFactory, AstNode, AstStatement, DirectAccessType, Operator,
ReferenceAccess, ReferenceExpr,
},
literals::AstLiteral,
try_from,
};
use plc_diagnostics::diagnostics::{Diagnostic, INTERNAL_LLVM_ERROR};
use plc_source::source_location::SourceLocation;
Expand Down Expand Up @@ -533,16 +534,17 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> {
}

fn assign_output_value(&self, param_context: &CallParameterAssignment) -> Result<(), Diagnostic> {
match param_context.assignment_statement.get_stmt() {
AstStatement::OutputAssignment(data) | AstStatement::Assignment(data) => self
.generate_explicit_output_assignment(
param_context.parameter_struct,
param_context.function_name,
&data.left,
&data.right,
),
_ => self.generate_output_assignment(param_context),
}
let Some(data) = try_from!(param_context.assignment_statement, Assignment) else {
// implicit parameter assignment
return self.generate_output_assignment(param_context);
};

self.generate_explicit_output_assignment(
param_context.parameter_struct,
param_context.function_name,
&data.left,
&data.right,
)
}

fn generate_output_assignment(&self, param_context: &CallParameterAssignment) -> Result<(), Diagnostic> {
Expand Down Expand Up @@ -643,10 +645,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> {
}
// TODO: find a more reliable way to make sure if this is a call into a local action!!
PouIndexEntry::Action { .. }
if matches!(
operator.get_stmt(),
AstStatement::ReferenceExpr(ReferenceExpr { base: None, .. })
) =>
if try_from!(operator, ReferenceExpr).is_some_and(|it| it.base.is_none()) =>
{
// special handling for local actions, get the parameter from the function context
function_context
Expand Down Expand Up @@ -1052,17 +1051,14 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> {
&self,
param_context: &CallParameterAssignment,
) -> Result<Option<BasicValueEnum<'ink>>, Diagnostic> {
let parameter_value = match param_context.assignment_statement.get_stmt() {
// explicit call parameter: foo(param := value)
AstStatement::OutputAssignment(data) | AstStatement::Assignment(data) => {
self.generate_formal_parameter(param_context, &data.left, &data.right)?;
None
}
let Some(data) = try_from!(param_context.assignment_statement, Assignment) else {
// foo(x)
_ => self.generate_nameless_parameter(param_context)?,
return Ok(self.generate_nameless_parameter(param_context)?);
};

Ok(parameter_value)
// explicit call parameter: foo(param := value)
self.generate_formal_parameter(param_context, &data.left, &data.right)?;
Ok(None)
}

/// generates the appropriate value for the given expression where the expression
Expand Down
28 changes: 12 additions & 16 deletions src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use plc_ast::{
control_statements::{AstControlStatement, ReturnStatement},
literals::{Array, AstLiteral, StringValue},
provider::IdProvider,
try_from,
};
use plc_source::source_location::SourceLocation;
use plc_util::convention::internal_type_name;
Expand Down Expand Up @@ -1799,7 +1800,7 @@ impl<'i> TypeAnnotator<'i> {
}

pub(crate) fn annotate_parameters(&mut self, p: &AstNode, type_name: &str) {
if !matches!(p.get_stmt(), AstStatement::Assignment(..) | AstStatement::OutputAssignment(..)) {
if try_from!(p, ast::Assignment).is_none() {
if let Some(effective_member_type) = self.index.find_effective_type_by_name(type_name) {
//update the type hint
self.annotation_map
Expand Down Expand Up @@ -2122,25 +2123,20 @@ fn accept_cast_string_literal(
literal: &AstNode,
) {
// check if we need to register an additional string-literal
match (cast_type.get_type_information(), literal.get_stmt()) {
(
DataTypeInformation::String { encoding: StringEncoding::Utf8, .. },
AstStatement::Literal(AstLiteral::String(StringValue { value, is_wide: is_wide @ true })),
)
| (
DataTypeInformation::String { encoding: StringEncoding::Utf16, .. },
AstStatement::Literal(AstLiteral::String(StringValue { value, is_wide: is_wide @ false })),
) => {
let Some(&AstLiteral::String(StringValue { ref value, is_wide })) = try_from!(literal, AstLiteral) else {
return;
};
match (cast_type.get_type_information(), is_wide) {
(DataTypeInformation::String { encoding: StringEncoding::Utf8, .. }, true)
| (DataTypeInformation::String { encoding: StringEncoding::Utf16, .. }, false) => {
// re-register the string-literal in the opposite encoding
if *is_wide {
literals.utf08.insert(value.to_string());
if is_wide {
literals.utf08.insert(value.into());
} else {
literals.utf16.insert(value.to_string());
literals.utf16.insert(value.into());
}
}
_ => {
//ignore
}
_ => (),
}
}

Expand Down

0 comments on commit cf35575

Please sign in to comment.