Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5680,6 +5680,46 @@ impl fmt::Display for FunctionBehavior {
}
}

/// These attributes describe the behavior of the function when called with a null argument.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionCalledOnNull {
CalledOnNullInput,
ReturnsNullOnNullInput,
Strict,
}

impl fmt::Display for FunctionCalledOnNull {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionCalledOnNull::CalledOnNullInput => write!(f, "CALLED ON NULL INPUT"),
FunctionCalledOnNull::ReturnsNullOnNullInput => write!(f, "RETURNS NULL ON NULL INPUT"),
FunctionCalledOnNull::Strict => write!(f, "STRICT"),
}
}
}

/// If it is safe for PostgreSQL to call the function from multiple threads at once
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum FunctionParallel {
Unsafe,
Restricted,
Safe,
}

impl fmt::Display for FunctionParallel {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
FunctionParallel::Unsafe => write!(f, "PARALLEL UNSAFE"),
FunctionParallel::Restricted => write!(f, "PARALLEL RESTRICTED"),
FunctionParallel::Safe => write!(f, "PARALLEL SAFE"),
}
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand All @@ -5700,7 +5740,7 @@ impl fmt::Display for FunctionDefinition {

/// Postgres specific feature.
///
/// See [Postgresdocs](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// See [Postgres docs](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// for more details
#[derive(Debug, Default, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand All @@ -5710,6 +5750,10 @@ pub struct CreateFunctionBody {
pub language: Option<Ident>,
/// IMMUTABLE | STABLE | VOLATILE
pub behavior: Option<FunctionBehavior>,
/// CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT
pub called_on_null: Option<FunctionCalledOnNull>,
/// PARALLEL { UNSAFE | RESTRICTED | SAFE }
pub parallel: Option<FunctionParallel>,
/// AS 'definition'
///
/// Note that Hive's `AS class_name` is also parsed here.
Expand All @@ -5728,6 +5772,12 @@ impl fmt::Display for CreateFunctionBody {
if let Some(behavior) = &self.behavior {
write!(f, " {behavior}")?;
}
if let Some(called_on_null) = &self.called_on_null {
write!(f, " {called_on_null}")?;
}
if let Some(parallel) = &self.parallel {
write!(f, " {parallel}")?;
}
if let Some(definition) = &self.as_ {
write!(f, " AS {definition}")?;
}
Expand Down
5 changes: 5 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ define_keywords!(
INITIALLY,
INNER,
INOUT,
INPUT,
INPUTFORMAT,
INSENSITIVE,
INSERT,
Expand Down Expand Up @@ -498,6 +499,7 @@ define_keywords!(
OVERLAY,
OVERWRITE,
OWNED,
PARALLEL,
PARAMETER,
PARQUET,
PARTITION,
Expand Down Expand Up @@ -570,6 +572,7 @@ define_keywords!(
RESPECT,
RESTART,
RESTRICT,
RESTRICTED,
RESULT,
RESULTSET,
RETAIN,
Expand All @@ -589,6 +592,7 @@ define_keywords!(
ROW_NUMBER,
RULE,
RUN,
SAFE,
SAFE_CAST,
SAVEPOINT,
SCHEMA,
Expand Down Expand Up @@ -704,6 +708,7 @@ define_keywords!(
UNLOGGED,
UNNEST,
UNPIVOT,
UNSAFE,
UNSIGNED,
UNTIL,
UPDATE,
Expand Down
40 changes: 40 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3437,6 +3437,46 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::VOLATILE) {
ensure_not_set(&body.behavior, "IMMUTABLE | STABLE | VOLATILE")?;
body.behavior = Some(FunctionBehavior::Volatile);
} else if self.parse_keywords(&[
Keyword::CALLED,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::CalledOnNullInput);
} else if self.parse_keywords(&[
Keyword::RETURNS,
Keyword::NULL,
Keyword::ON,
Keyword::NULL,
Keyword::INPUT,
]) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::ReturnsNullOnNullInput);
} else if self.parse_keyword(Keyword::STRICT) {
ensure_not_set(
&body.called_on_null,
"CALLED ON NULL INPUT | RETURNS NULL ON NULL INPUT | STRICT",
)?;
body.called_on_null = Some(FunctionCalledOnNull::Strict);
} else if self.parse_keyword(Keyword::PARALLEL) {
ensure_not_set(&body.parallel, "PARALLEL { UNSAFE | RESTRICTED | SAFE }")?;
if self.parse_keyword(Keyword::UNSAFE) {
body.parallel = Some(FunctionParallel::Unsafe);
} else if self.parse_keyword(Keyword::RESTRICTED) {
body.parallel = Some(FunctionParallel::Restricted);
} else if self.parse_keyword(Keyword::SAFE) {
body.parallel = Some(FunctionParallel::Safe);
} else {
return self.expected("one of UNSAFE | RESTRICTED | SAFE", self.peek_token());
}
} else if self.parse_keyword(Keyword::RETURN) {
ensure_not_set(&body.return_, "RETURN")?;
body.return_ = Some(self.parse_expr()?);
Expand Down
48 changes: 46 additions & 2 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3280,7 +3280,7 @@ fn parse_similar_to() {

#[test]
fn parse_create_function() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE AS 'select $1 + $2;'";
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL IMMUTABLE STRICT PARALLEL SAFE AS 'select $1 + $2;'";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
Expand All @@ -3295,6 +3295,8 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::Strict),
parallel: Some(FunctionParallel::Safe),
as_: Some(FunctionDefinition::SingleQuotedDef(
"select $1 + $2;".into()
)),
Expand All @@ -3303,7 +3305,7 @@ fn parse_create_function() {
}
);

let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURN a + b";
let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT PARALLEL RESTRICTED RETURN a + b";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
Expand All @@ -3323,6 +3325,40 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Immutable),
called_on_null: Some(FunctionCalledOnNull::ReturnsNullOnNullInput),
parallel: Some(FunctionParallel::Restricted),
return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus,
right: Box::new(Expr::Identifier("b".into())),
}),
..Default::default()
},
}
);

let sql = "CREATE OR REPLACE FUNCTION add(a INTEGER, IN b INTEGER = 1) RETURNS INTEGER LANGUAGE SQL STABLE CALLED ON NULL INPUT PARALLEL UNSAFE RETURN a + b";
assert_eq!(
pg_and_generic().verified_stmt(sql),
Statement::CreateFunction {
or_replace: true,
temporary: false,
name: ObjectName(vec![Ident::new("add")]),
args: Some(vec![
OperateFunctionArg::with_name("a", DataType::Integer(None)),
OperateFunctionArg {
mode: Some(ArgMode::In),
name: Some("b".into()),
data_type: DataType::Integer(None),
default_expr: Some(Expr::Value(Value::Number("1".parse().unwrap(), false))),
}
]),
return_type: Some(DataType::Integer(None)),
params: CreateFunctionBody {
language: Some("SQL".into()),
behavior: Some(FunctionBehavior::Stable),
called_on_null: Some(FunctionCalledOnNull::CalledOnNullInput),
parallel: Some(FunctionParallel::Unsafe),
return_: Some(Expr::BinaryOp {
left: Box::new(Expr::Identifier("a".into())),
op: BinaryOperator::Plus,
Expand All @@ -3348,6 +3384,8 @@ fn parse_create_function() {
params: CreateFunctionBody {
language: Some("plpgsql".into()),
behavior: None,
called_on_null: None,
parallel: None,
return_: None,
as_: Some(FunctionDefinition::DoubleDollarDef(
" BEGIN RETURN i + 1; END; ".into()
Expand All @@ -3358,6 +3396,12 @@ fn parse_create_function() {
);
}

#[test]
fn parse_incorrect_create_function_parallel() {
let sql = "CREATE FUNCTION add(INTEGER, INTEGER) RETURNS INTEGER LANGUAGE SQL PARALLEL BLAH AS 'select $1 + $2;'";
assert!(pg().parse_sql_statements(sql).is_err());
}

#[test]
fn parse_drop_function() {
let sql = "DROP FUNCTION IF EXISTS test_func";
Expand Down