Skip to content
This repository was archived by the owner on May 6, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[alias]
gha_clippy = "clippy --all-features -- -D warnings -W clippy::nursery"
gha_clippy = "clippy --all-features -- -D warnings -W clippy::pedantic -W clippy::nursery -W rust-2018-idioms"
gha_fmt = "fmt --all"
31 changes: 27 additions & 4 deletions src/aggregation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::{self, Display};

use serde::{Deserialize, Serialize};
use sqlparser::ast;
use sqlparser::ast::{self};
use utoipa::{IntoParams, ToSchema};

use crate::{
Expand Down Expand Up @@ -102,6 +102,10 @@ impl Aggregation {
match &case_fold_identifier(unqualified_name)[..] {
"sum" => return only_column_arg(KoronFunction::Sum),
"count" => return only_column_arg(KoronFunction::Count),
"avg" => return only_column_arg(KoronFunction::Average),
"median" => return only_column_arg(KoronFunction::Median),
"variance" => return only_column_arg(KoronFunction::Variance),
"stddev" => return only_column_arg(KoronFunction::StandardDeviation),
_ => (),
}
}
Expand Down Expand Up @@ -174,13 +178,25 @@ pub enum KoronFunction {
/// The `count` aggregation function.
#[default]
Count,
/// The `average` aggregation function.
Average,
/// The `median` aggregation function.
Median,
/// The `variance` aggregation function.
Variance,
/// The `stddev` aggregation function.
StandardDeviation,
}

impl Display for KoronFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Sum => write!(f, "Sum"),
Self::Count => write!(f, "Count"),
Self::Sum => write!(f, "SUM"),
Self::Count => write!(f, "COUNT"),
Self::Average => write!(f, "AVG"),
Self::Median => write!(f, "MEDIAN"),
Self::Variance => write!(f, "VARIANCE"),
Self::StandardDeviation => write!(f, "STDDEV"),
}
}
}
Expand All @@ -191,7 +207,14 @@ mod tests {

#[test]
fn koron_fn_display() {
let cases = [(KoronFunction::Count, "Count"), (KoronFunction::Sum, "Sum")];
let cases = [
(KoronFunction::Count, "COUNT"),
(KoronFunction::Sum, "SUM"),
(KoronFunction::Variance, "VARIANCE"),
(KoronFunction::Median, "MEDIAN"),
(KoronFunction::Average, "AVG"),
(KoronFunction::StandardDeviation, "STDDEV"),
];
for (koron_fn, expected) in cases {
assert_eq!(koron_fn.to_string(), expected.to_string());
}
Expand Down
124 changes: 85 additions & 39 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,23 @@ mod tests {
let cases = [
("SUM(test_column_2)", KoronFunction::Sum),
("COUNT(test_column_2)", KoronFunction::Count),
("AVG(test_column_2)", KoronFunction::Average),
("MEDIAN(test_column_2)", KoronFunction::Median),
("VARIANCE(test_column_2)", KoronFunction::Variance),
("STDDEV(test_column_2)", KoronFunction::StandardDeviation),
];

for (projection, function) in cases {
let query = &format!("SELECT {projection} FROM test_db.test_schema.test_table_1");

let data_aggregation_query = if function == KoronFunction::Median {
None
} else {
Some(format!(
"SELECT CAST({projection} AS TEXT) FROM test_db.test_schema.test_table_1"
))
};

let expected = Ok(QueryMetadata {
table: sample_tab_ident(),
aggregation: Aggregation {
Expand All @@ -54,9 +66,13 @@ mod tests {
alias: None,
},
filter: None,
data_extraction_query: String::from(
"SELECT test_column_2 FROM test_db.test_schema.test_table_1",
),
data_aggregation_query,
});
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for aggregation {projection}",
);
Expand All @@ -70,8 +86,14 @@ mod tests {
table: sample_tab_ident(),
aggregation: sample_sum(),
filter: None,
data_extraction_query: String::from(
"SELECT test_column_2 FROM test_db.test_schema.test_table_1",
),
data_aggregation_query: Some(String::from(
"SELECT CAST(SUM(test_column_2) AS TEXT) FROM test_db.test_schema.test_table_1",
)),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -81,8 +103,10 @@ mod tests {
table: sample_tab_ident(),
aggregation: sample_sum(),
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST((((SUM(test_column_2)))) AS TEXT) FROM test_db.test_schema.test_table_1")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -92,8 +116,10 @@ mod tests {
table: sample_tab_ident(),
aggregation: sample_sum(),
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST(SUM((((test_column_2)))) AS TEXT) FROM test_db.test_schema.test_table_1")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -107,8 +133,10 @@ mod tests {
alias: Some("s".to_string()),
},
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST(SUM(test_column_2) AS TEXT) AS s FROM test_db.test_schema.test_table_1")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -118,8 +146,10 @@ mod tests {
table: sample_tab_ident(),
aggregation: sample_sum(),
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST(SUM(test_column_2) AS TEXT) FROM test_db.test_schema.test_table_1 AS t")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -129,8 +159,14 @@ mod tests {
table: sample_tab_ident(),
aggregation: sample_sum(),
filter: None,
data_extraction_query: String::from(
"SELECT test_column_2 FROM test_db.test_schema.test_table_1",
),
data_aggregation_query: Some(String::from(
"SELECT CAST(sum(test_column_2) AS TEXT) FROM test_db.test_schema.test_table_1",
)),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -139,7 +175,7 @@ mod tests {
let expected = Err(unsupported!(
"unrecognized or unsupported function: \"SUM\".".to_string()
));
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -153,8 +189,10 @@ mod tests {
alias: Some("s".to_string()),
},
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST(SUM(test_column_2) AS TEXT) AS S FROM test_db.test_schema.test_table_1")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -168,8 +206,10 @@ mod tests {
alias: Some("S".to_string()),
},
filter: None,
data_extraction_query:String::from("SELECT test_column_2 FROM test_db.test_schema.test_table_1"),
data_aggregation_query: Some(String::from("SELECT CAST(SUM(test_column_2) AS TEXT) AS \"S\" FROM test_db.test_schema.test_table_1")),
});
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -186,7 +226,7 @@ mod tests {
the table that's listed in the FROM clause ({extracted_alias}).",
)));
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for query {query:?}",
);
Expand All @@ -206,7 +246,7 @@ mod tests {
the table that's listed in the FROM clause (test_db.test_schema.test_table_1).",
)));
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for query {query:?}",
);
Expand All @@ -225,7 +265,7 @@ mod tests {
the table that's listed in the FROM clause (t).",
)));
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for query {query:?}",
);
Expand All @@ -238,22 +278,22 @@ mod tests {
let expected = Err(malformed_query!(
"sql parser error: Expected identifier, found: EOF".to_string()
));
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
fn table_name_too_many_name_parts() {
let query = "SELECT SUM(test_column_2) FROM x.test_db.test_schema.test_table_1";
let expected = Err(internal!("found too many ident in table name (i.e., x.test_db.test_schema.test_table_1) in query AST.".to_string()));
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
fn column_name_too_many_name_parts() {
let query = "SELECT SUM(x.test_db.test_schema.test_table_1.test_column_2) FROM test_db.test_schema.test_table_1";
let expected = Err(internal!("found too many ident in column name (i.e., x.test_db.test_schema.test_table_1.test_column_2)."
.to_string()));
assert_eq!(QueryMetadata::parse(query), expected);
assert_eq!(QueryMetadata::parse(query, None), expected);
}

#[test]
Expand All @@ -273,7 +313,7 @@ mod tests {
let query = &format!("SELECT {projection} FROM test_db.test_schema.test_table_1");
let expected = Err(malformed_query!(reason.to_string()));
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for aggregation {projection}",
);
Expand Down Expand Up @@ -473,22 +513,6 @@ mod tests {
"SELECT MAX(test_column_2) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: MAX."
),
(
"SELECT AVG(test_column_2) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: AVG."
),
(
"SELECT STDDEV(test_column_2) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: STDDEV."
),
(
"SELECT VARIANCE(test_column_2) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: VARIANCE."
),
(
"SELECT MEDIAN(test_column_2) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: MEDIAN."
),
(
"SELECT KTHELEMENT(test_column_2, 3) FROM test_db.test_schema.test_table_1;",
"unrecognized or unsupported function: KTHELEMENT."
Expand All @@ -498,7 +522,7 @@ mod tests {
for (query, reason) in cases {
let expected = Err(unsupported!(reason.to_string()));
assert_eq!(
QueryMetadata::parse(query),
QueryMetadata::parse(query, None),
expected,
"\nfailed for query {query:?}",
);
Expand Down Expand Up @@ -748,14 +772,36 @@ mod tests {
let query = &format!("{query} WHERE {selection}");
let mut aggregation = sample_sum();
aggregation.function = enum_fn;
let expected = Ok(QueryMetadata {
let expected_query = if &filter.column == "test_column_2" {
"SELECT test_column_2 FROM test_db.test_schema.test_table_1".to_string()
} else {
format!(
"SELECT test_column_2, {} FROM test_db.test_schema.test_table_1",
filter.column
)
};
let expected = QueryMetadata {
table: sample_tab_ident(),
aggregation,
filter: Some(filter),
});
filter: Some(filter.clone()),
data_extraction_query: expected_query,
data_aggregation_query: None,
};
let result = QueryMetadata::parse(query, None).unwrap();
assert_eq!(
result.aggregation, expected.aggregation,
"\nfailed for selection {selection:?}",
);
assert_eq!(
result.table, expected.table,
"\nfailed for selection {selection:?}",
);
assert_eq!(
result.filter, expected.filter,
"\nfailed for selection {selection:?}",
);
assert_eq!(
QueryMetadata::parse(query),
expected,
result.data_extraction_query, expected.data_extraction_query,
"\nfailed for selection {selection:?}",
);
}
Expand Down
Loading