diff --git a/CMakeLists.txt b/CMakeLists.txt index ceee647d28..ba13691398 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,10 +53,6 @@ set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ ISO Standard version") if (CMAKE_CXX_STANDARD LESS 20) message(FATAL_ERROR "C++ 2020 ISO Standard or higher is required to build SeQuant") endif () -# C++20 is only configurable via compile features with cmake 3.12 and older -if (CMAKE_CXX_STANDARD EQUAL 20 AND CMAKE_VERSION VERSION_LESS 3.12.0) - cmake_minimum_required(VERSION 3.12.0) -endif () set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF CACHE BOOL "Whether to use extensions of C++ ISO Standard version") @@ -248,6 +244,9 @@ include(FindOrFetchUtfcpp) # Eigen include(FindOrFetchEigen) +# cpp-peglib +include(FindOrFetchCppPeglib) + # embedded bliss-0.73 add_library(SeQuant-bliss SeQuant/external/bliss/defs.cc @@ -341,8 +340,16 @@ set(SeQuant_symb_src SeQuant/core/io/serialization/v1/ast.hpp SeQuant/core/io/serialization/v1/ast_conversions.hpp SeQuant/core/io/serialization/v1/deserialize.cpp + SeQuant/core/io/serialization/v1/error.cpp + SeQuant/core/io/serialization/v1/error.hpp SeQuant/core/io/serialization/v1/semantic_actions.hpp SeQuant/core/io/serialization/v1/serialize.cpp + SeQuant/core/io/serialization/v2/ast_conversions.cpp + SeQuant/core/io/serialization/v2/ast_conversions.hpp + SeQuant/core/io/serialization/v2/deserialize.cpp + SeQuant/core/io/serialization/v2/error.cpp + SeQuant/core/io/serialization/v2/error.hpp + SeQuant/core/io/serialization/v2/serialize.cpp SeQuant/core/ranges.hpp SeQuant/core/rational.hpp SeQuant/core/runtime.cpp @@ -570,6 +577,7 @@ target_link_libraries(SeQuant-symb Threads::Threads libperm::libperm $ + $ Eigen3::Eigen) if (Boost_IS_MODULARIZED) target_link_libraries(SeQuant-symb PUBLIC diff --git a/SeQuant/core/io/serialization/serialization.cpp b/SeQuant/core/io/serialization/serialization.cpp index fc14fac930..0979884245 100644 --- 
a/SeQuant/core/io/serialization/serialization.cpp +++ b/SeQuant/core/io/serialization/serialization.cpp @@ -1,41 +1,115 @@ #include #include +#include #include +#include #include namespace sequant::io::serialization { -SerializationError::SerializationError(std::size_t offset, std::size_t length, - std::string message) - : Exception(std::move(message)), offset(offset), length(length) {} +SerializationError::SerializationError(std::string message) + : Exception(std::move(message)) {} -#define SEQUANT_RESOLVE_DESERIALIZATION_FUNC(stringType, ExprType) \ +template +concept v1_can_deserialize = requires(StringType input) { + serialization::v1::from_string(input); +}; + +template +concept v2_can_deserialize = + requires(StringType input, DeserializationOptions opts) { + serialization::v2::from_string(input); + }; + +static_assert(!v2_can_deserialize); + +template +ExprType from_string_indirection(StringType input, + const DeserializationOptions &options) { + // Note: We need this indirection as the if constexpr construct only works (in + // the way we need) if the condition depends on a template parameter. If it + // didn't, we'd get a bunch of "call to deleted function" errors as the body + // of the if is checked even though the if is false. 
+ switch (options.syntax) { + case SerializationSyntax::V1: + if constexpr (v1_can_deserialize) { + return serialization::v1::from_string(input, options); + } else { + throw Exception( + "Deserialization of this type is not supported with syntax V1"); + } + case SerializationSyntax::V2: + if constexpr (v2_can_deserialize) { + return serialization::v2::from_string(input, options); + } else { + throw Exception( + "Deserialization of this type is not supported with syntax V2"); + } + } + + SEQUANT_UNREACHABLE; +} + +#define SEQUANT_RESOLVE_DESERIALIZATION_FUNC(StringType, ExprType) \ template <> \ - ExprType from_string(stringType input, \ + ExprType from_string(StringType input, \ const DeserializationOptions &options) { \ - switch (options.syntax) { \ - case SerializationSyntax::V1: \ - return serialization::v1::from_string(input, options); \ - } \ - \ - SEQUANT_UNREACHABLE; \ + return from_string_indirection(input, options); \ } SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, ExprPtr) SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, ExprPtr) SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, ResultExpr) SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, ResultExpr) +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, Constant); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, Constant); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, Variable); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, Variable); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, Tensor); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, Tensor); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, Product); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, Product); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::string_view, Sum); +SEQUANT_RESOLVE_DESERIALIZATION_FUNC(std::wstring_view, Sum); + +template +concept v1_can_serialize = requires(const ExprType &expr) { + { serialization::v1::to_string(expr) } -> 
std::convertible_to; +}; + +template +concept v2_can_serialize = requires(const ExprType &expr) { + { serialization::v2::to_string(expr) } -> std::convertible_to; +}; + +template +std::wstring to_string_indirection(const Arg &arg, + const SerializationOptions &options) { + switch (options.syntax) { + case SerializationSyntax::V1: + if constexpr (v1_can_serialize) { + return serialization::v1::to_string(arg, options); + } else { + throw Exception( + "Serialization of this type is not supported with syntax V1"); + } + case SerializationSyntax::V2: + if constexpr (v2_can_serialize) { + return serialization::v2::to_string(arg, options); + } else { + throw Exception( + "Serialization of this type is not supported with syntax V2"); + } + } + + SEQUANT_UNREACHABLE; +} #define SEQUANT_RESOLVE_SERIALIZE_FUNC(argType) \ std::wstring to_string(const argType &arg, \ const SerializationOptions &options) { \ - switch (options.syntax) { \ - case SerializationSyntax::V1: \ - return serialization::v1::to_string(arg, options); \ - } \ - \ - SEQUANT_UNREACHABLE; \ + return to_string_indirection(arg, options); \ } SEQUANT_RESOLVE_SERIALIZE_FUNC(ResultExpr) diff --git a/SeQuant/core/io/serialization/serialization.hpp b/SeQuant/core/io/serialization/serialization.hpp index d4d29604c6..6fcaff7d3e 100644 --- a/SeQuant/core/io/serialization/serialization.hpp +++ b/SeQuant/core/io/serialization/serialization.hpp @@ -14,11 +14,7 @@ namespace sequant::io::serialization { struct SerializationError : Exception { - std::size_t offset; - std::size_t length; - - SerializationError(std::size_t offset, std::size_t length, - std::string message); + SerializationError(std::string message); }; /// Specifies the syntax of the textual input/representation to use. All @@ -32,6 +28,7 @@ struct SerializationError : Exception { /// and support for it may be removed in future versions. 
enum class SerializationSyntax { V1, + V2, Latest = V1 }; @@ -83,6 +80,11 @@ SEQUANT_DECLARE_DESERIALIZATION_FUNC; SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(ExprPtr); SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(ResultExpr); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Constant); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Variable); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Tensor); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Product); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Sum); #define SEQUANT_DECLARE_SERIALIZATION_FUNC \ std::wstring to_string(const ResultExpr &expr, \ @@ -125,6 +127,23 @@ SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(ResultExpr); SEQUANT_DECLARE_SERIALIZATION_FUNC } // namespace v1 +namespace v2 { +SEQUANT_DECLARE_DESERIALIZATION_FUNC; + +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(ExprPtr); +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(ResultExpr); +SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Constant); +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Variable); +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Tensor); +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Product); +// SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION(Sum); + +// Dummy +inline void to_string() {} + +// SEQUANT_DECLARE_SERIALIZATION_FUNC +} // namespace v2 + #undef SEQUANT_DECLARE_DESERIALIZATION_FUNC #undef SEQUANT_DECLARE_DESERIALIZATION_FUNC_SPECIALIZATION #undef SEQUANT_DECLARE_SERIALIZATION_FUNC diff --git a/SeQuant/core/io/serialization/v1/ast.hpp b/SeQuant/core/io/serialization/v1/ast.hpp index 55d5c6124a..050ff278f0 100644 --- a/SeQuant/core/io/serialization/v1/ast.hpp +++ b/SeQuant/core/io/serialization/v1/ast.hpp @@ -2,8 +2,8 @@ // Created by Robert Adam on 2023-09-21 // -#ifndef SEQUANT_CORE_PARSE_V1_AST_HPP -#define SEQUANT_CORE_PARSE_V1_AST_HPP +#ifndef SEQUANT_CORE_IO_SERIALIZATION_V1_AST_HPP +#define 
SEQUANT_CORE_IO_SERIALIZATION_V1_AST_HPP #define BOOST_SPIRIT_X3_UNICODE #include @@ -154,4 +154,4 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::io::serialization::v1::ast::Sum, summands); BOOST_FUSION_ADAPT_STRUCT(sequant::io::serialization::v1::ast::ResultExpr, lhs, rhs); -#endif // SEQUANT_CORE_PARSE_AST_V1_HPP +#endif // SEQUANT_CORE_IO_SERIALIZATION_V1_AST_HPP diff --git a/SeQuant/core/io/serialization/v1/ast_conversions.hpp b/SeQuant/core/io/serialization/v1/ast_conversions.hpp index cae45a75f4..9353488bc8 100644 --- a/SeQuant/core/io/serialization/v1/ast_conversions.hpp +++ b/SeQuant/core/io/serialization/v1/ast_conversions.hpp @@ -2,13 +2,14 @@ // Created by Robert Adam on 2023-09-22 // -#ifndef SEQUANT_CORE_PARSE_AST_CONVERSIONS_HPP -#define SEQUANT_CORE_PARSE_AST_CONVERSIONS_HPP +#ifndef SEQUANT_CORE_IO_SERIALIZATION_V1_AST_CONVERSIONS_HPP +#define SEQUANT_CORE_IO_SERIALIZATION_V1_AST_CONVERSIONS_HPP #include #include #include #include +#include #include #include #include @@ -49,15 +50,14 @@ Index to_index(const io::serialization::v1::ast::Index &index, protoIndices.push_back(Index(std::move(label), std::move(space))); } catch (const IndexSpace::bad_key &) { auto [offset, length] = get_pos(current, position_cache, begin); - throw SerializationError(offset, length, - "Unknown index space '" + toUtf8(current.label) + - "' in proto index specification"); + throw Error(offset, length, + "Unknown index space '" + toUtf8(current.label) + + "' in proto index specification"); } catch (const Exception &e) { auto [offset, length] = get_pos(current, position_cache, begin); - throw SerializationError(offset, length, - "Invalid index '" + toUtf8(current.label) + "_" + - std::to_string(current.id) + ": " + - e.what()); + throw Error(offset, length, + "Invalid index '" + toUtf8(current.label) + "_" + + std::to_string(current.id) + ": " + e.what()); } } @@ -67,16 +67,14 @@ Index to_index(const io::serialization::v1::ast::Index &index, return Index(std::move(space), 
index.label.id, std::move(protoIndices)); } catch (const IndexSpace::bad_key &e) { auto [offset, length] = get_pos(index.label, position_cache, begin); - throw SerializationError(offset, length, - "Unknown index space '" + - toUtf8(index.label.label) + - "' in index specification"); + throw Error(offset, length, + "Unknown index space '" + toUtf8(index.label.label) + + "' in index specification"); } catch (const Exception &e) { auto [offset, length] = get_pos(index.label, position_cache, begin); - throw SerializationError(offset, length, - "Invalid index '" + toUtf8(index.label.label) + - "_" + std::to_string(index.label.id) + ": " + - e.what()); + throw Error(offset, length, + "Invalid index '" + toUtf8(index.label.label) + "_" + + std::to_string(index.label.id) + ": " + e.what()); } } @@ -136,8 +134,7 @@ Symmetry to_perm_symmetry(char c, std::size_t offset, const Iterator &, return Symmetry::Nonsymm; } - throw SerializationError( - offset, 1, std::string("Invalid symmetry specifier '") + c + "'"); + throw Error(offset, 1, std::string("Invalid symmetry specifier '") + c + "'"); } template @@ -159,8 +156,8 @@ BraKetSymmetry to_braket_symmetry(char c, std::size_t offset, const Iterator &, return BraKetSymmetry::Nonsymm; } - throw SerializationError( - offset, 1, std::string("Invalid BraKet symmetry specifier '") + c + "'"); + throw Error(offset, 1, + std::string("Invalid BraKet symmetry specifier '") + c + "'"); } template @@ -179,9 +176,8 @@ ColumnSymmetry to_column_symmetry(char c, std::size_t offset, const Iterator &, return ColumnSymmetry::Nonsymm; } - throw SerializationError( - offset, 1, - std::string("Invalid particle symmetry specifier '") + c + "'"); + throw Error(offset, 1, + std::string("Invalid particle symmetry specifier '") + c + "'"); } template @@ -412,11 +408,11 @@ ResultExpr ast_to_result(const io::serialization::v1::ast::ResultExpr &result, return {std::move(lhs.as()), std::move(rhs)}; } else { auto [offset, length] = get_pos(result.lhs, 
position_cache, begin); - throw SerializationError( - offset, length, "LHS of a ResultExpr must be a Tensor or a Variable"); + throw Error(offset, length, + "LHS of a ResultExpr must be a Tensor or a Variable"); } } } // namespace sequant::io::serialization::v1::transform -#endif // SEQUANT_CORE_PARSE_AST_CONVERSIONS_HPP +#endif // SEQUANT_CORE_IO_SERIALIZATION_V1_AST_CONVERSIONS_HPP diff --git a/SeQuant/core/io/serialization/v1/deserialize.cpp b/SeQuant/core/io/serialization/v1/deserialize.cpp index 9b55ae7512..c9ec271034 100644 --- a/SeQuant/core/io/serialization/v1/deserialize.cpp +++ b/SeQuant/core/io/serialization/v1/deserialize.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -184,10 +185,9 @@ struct ErrorHandler { void operator()(Iterator where, std::string expected) const { std::size_t offset = std::distance(begin, where); - throw SerializationError(offset, 1, - std::string("Parse failure at offset ") + - std::to_string(offset) + ": expected " + - expected); + throw Error(offset, 1, + std::string("Parse failure at offset ") + + std::to_string(offset) + ": expected " + expected); } }; @@ -210,15 +210,15 @@ AST parse(const StartRule &start, std::wstring_view input, if (!success) { // Normally, this shouldn't happen as any error should itself throw a - // SerializationError already - throw SerializationError( - 0, input.size(), "Parsing was unsuccessful for an unknown reason"); + // Error already + throw Error(0, input.size(), + "Parsing was unsuccessful for an unknown reason"); } if (begin != input.end()) { // This should also not happen as the parser requires matching EOI - throw SerializationError(std::distance(input.begin(), begin), - std::distance(begin, input.end()), - "Couldn't parse the entire input"); + throw Error(std::distance(input.begin(), begin), + std::distance(begin, input.end()), + "Couldn't parse the entire input"); } } catch ([[maybe_unused]] const boost::spirit::x3::expectation_failure< iterator_type> &e) { diff 
--git a/SeQuant/core/io/serialization/v1/error.cpp b/SeQuant/core/io/serialization/v1/error.cpp new file mode 100644 index 0000000000..e383a4d2b0 --- /dev/null +++ b/SeQuant/core/io/serialization/v1/error.cpp @@ -0,0 +1,8 @@ +#include + +namespace sequant::io::serialization::v1 { + +Error::Error(std::size_t offset, std::size_t length, std::string msg) + : SerializationError(std::move(msg)), offset(offset), length(length) {} + +} // namespace sequant::io::serialization::v1 diff --git a/SeQuant/core/io/serialization/v1/error.hpp b/SeQuant/core/io/serialization/v1/error.hpp new file mode 100644 index 0000000000..b3acdd5daa --- /dev/null +++ b/SeQuant/core/io/serialization/v1/error.hpp @@ -0,0 +1,20 @@ +#ifndef SEQUANT_CORE_IO_SERIALIZATION_V1_ERROR_HPP +#define SEQUANT_CORE_IO_SERIALIZATION_V1_ERROR_HPP + +#include + +#include +#include + +namespace sequant::io::serialization::v1 { + +struct Error : SerializationError { + std::size_t offset; + std::size_t length; + + Error(std::size_t offset, std::size_t length, std::string msg); +}; + +} // namespace sequant::io::serialization::v1 + +#endif // SEQUANT_CORE_IO_SERIALIZATION_V1_ERROR_HPP diff --git a/SeQuant/core/io/serialization/v1/semantic_actions.hpp b/SeQuant/core/io/serialization/v1/semantic_actions.hpp index 26dab219a7..f98a7dbe02 100644 --- a/SeQuant/core/io/serialization/v1/semantic_actions.hpp +++ b/SeQuant/core/io/serialization/v1/semantic_actions.hpp @@ -2,8 +2,8 @@ // Created by Robert Adam on 2023-09-21 // -#ifndef SEQUANT_CORE_PARSE_SEMANTIC_ACTIONS_V1_HPP -#define SEQUANT_CORE_PARSE_SEMANTIC_ACTIONS_V1_HPP +#ifndef SEQUANT_CORE_IO_SERIALIZATION_V1_SEMANTIC_ACTIONS_HPP +#define SEQUANT_CORE_IO_SERIALIZATION_V1_SEMANTIC_ACTIONS_HPP #include #include @@ -72,4 +72,4 @@ struct process_addend { } // namespace sequant::io::serialization::v1::actions -#endif // SEQUANT_CORE_PARSE_SEMANTIC_ACTIONS_V1_HPP +#endif // SEQUANT_CORE_IO_SERIALIZATION_V1_SEMANTIC_ACTIONS_HPP diff --git 
a/SeQuant/core/io/serialization/v2/ast_conversions.cpp b/SeQuant/core/io/serialization/v2/ast_conversions.cpp
new file mode 100644
index 0000000000..7e95c07fc3
--- /dev/null
+++ b/SeQuant/core/io/serialization/v2/ast_conversions.cpp
@@ -0,0 +1,82 @@
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+namespace sequant::io::serialization::v2 {
+
+/// Converts an "Integer" AST node to its integral value by parsing its token.
+std::int64_t to_int(const peg::Ast &ast) {
+  SEQUANT_ASSERT(ast.name == "Integer");
+
+  return string_to<std::int64_t>(ast.token);
+}
+
+/// Converts a "Float" AST node to a double by parsing its token.
+double to_double(const peg::Ast &ast) {
+  SEQUANT_ASSERT(ast.name == "Float");
+
+  return string_to<double>(ast.token);
+}
+
+/// Converts a "Real" AST node, which wraps exactly one child: an "Integer"
+/// child is converted via to_int, any other child via to_double.
+double to_real(const peg::Ast &ast) {
+  SEQUANT_ASSERT(ast.name == "Real");
+  SEQUANT_ASSERT(ast.nodes.size() == 1);
+  SEQUANT_ASSERT(ast.nodes[0]);
+
+  if (ast.nodes[0]->name == "Integer") {
+    return to_int(*ast.nodes[0]);
+  } else {
+    return to_double(*ast.nodes[0]);
+  }
+}
+
+/// Converts a "Complex" AST node, which wraps exactly one child. An
+/// "Imaginary" child yields a purely imaginary value; any other child is
+/// converted via to_real and yields a purely real value.
+Constant::scalar_type to_complex(const peg::Ast &ast) {
+  SEQUANT_ASSERT(ast.name == "Complex");
+  SEQUANT_ASSERT(ast.nodes.size() == 1);
+  SEQUANT_ASSERT(ast.nodes[0]);
+
+  if (ast.nodes[0]->name == "Imaginary") {
+    const auto &imag_node = *ast.nodes[0];
+    // Validate the node that is dereferenced below (the child of imag_node)
+    // rather than re-checking the parent that has been validated already
+    SEQUANT_ASSERT(imag_node.nodes.size() == 1);
+    SEQUANT_ASSERT(imag_node.nodes[0]);
+    return Constant::scalar_type(0, to_real(*imag_node.nodes[0]));
+  } else {
+    return Constant::scalar_type(to_real(*ast.nodes[0]), 0);
+  }
+}
+
+/// Converts a "Constant" AST node (wrapping exactly one child that is handed
+/// to to_complex) into a Constant. The options are not consulted.
+template <>
+Constant ast_to<Constant>(const peg::Ast &ast, const DeserializationOptions &) {
+  SEQUANT_ASSERT(ast.name == "Constant");
+  SEQUANT_ASSERT(ast.nodes.size() == 1);
+  SEQUANT_ASSERT(ast.nodes[0]);
+
+  return Constant(to_complex(*ast.nodes[0]));
+}
+
+/// Placeholder: ignores both arguments and returns a default-constructed
+/// ExprPtr.
+template <>
+ExprPtr ast_to<ExprPtr>(const peg::Ast &ast,
+                        const DeserializationOptions &options) {
+  (void)ast;
+  (void)options;
+  return {};
+}
+
+}  // namespace sequant::io::serialization::v2
diff --git a/SeQuant/core/io/serialization/v2/ast_conversions.hpp b/SeQuant/core/io/serialization/v2/ast_conversions.hpp
new file mode 100644
index 0000000000..e4fca1524e
--- /dev/null
+++
b/SeQuant/core/io/serialization/v2/ast_conversions.hpp
@@ -0,0 +1,27 @@
+#ifndef SEQUANT_CORE_IO_SERIALIZATION_V2_AST_CONVERSIONS_HPP
+#define SEQUANT_CORE_IO_SERIALIZATION_V2_AST_CONVERSIONS_HPP
+
+#include
+#include
+#include
+
+#include
+
+namespace sequant::io::serialization::v2 {
+
+/// Converts the given AST node into an object of type T. The primary template
+/// is deleted; only the explicit specializations declared below are usable.
+template <typename T>
+T ast_to(const peg::Ast &ast, const DeserializationOptions &options) = delete;
+
+template <>
+Constant ast_to<Constant>(const peg::Ast &ast,
+                          const DeserializationOptions &options);
+
+template <>
+ExprPtr ast_to<ExprPtr>(const peg::Ast &ast,
+                        const DeserializationOptions &options);
+
+}  // namespace sequant::io::serialization::v2
+
+#endif  // SEQUANT_CORE_IO_SERIALIZATION_V2_AST_CONVERSIONS_HPP
diff --git a/SeQuant/core/io/serialization/v2/deserialize.cpp b/SeQuant/core/io/serialization/v2/deserialize.cpp
new file mode 100644
index 0000000000..49b27eddd5
--- /dev/null
+++ b/SeQuant/core/io/serialization/v2/deserialize.cpp
@@ -0,0 +1,122 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+namespace sequant::io::serialization::v2 {
+
+// Defines a variable peg_serialization_grammar that contains the definition of
+// the grammar (as a string_view)
+#include "peg_grammar.ipp"
+
+/// Returns this thread's PEG parser, configured to start at the given rule.
+///
+/// The parser is created lazily (once per thread) and cached. The grammar is
+/// only reloaded if the requested start rule differs from the one the cached
+/// parser has most recently been loaded with.
+///
+/// @param start_rule Name of the grammar rule at which parsing shall start
+/// @returns Reference to the thread-local parser
+/// @throws Error if the grammar can't be loaded
+static peg::parser &get_parser(std::string_view start_rule) {
+  static thread_local bool initialized = false;
+  static thread_local std::string last_start;
+  static thread_local peg::parser parser;
+
+  peg::Log logger = [](std::size_t line, std::size_t column,
+                       const std::string &msg, const std::string &rule) {
+    throw Error(line, column, rule,
+                "Deserialization failed at line " + std::to_string(line) + ":" +
+                    std::to_string(column) + ": " + msg + " (" + rule + ")");
+  };
+
+  if (!initialized) {
+    // This first logger will be notified in case there are issues with parsing
+    // the grammar itself
+    parser.set_logger([](std::size_t line, std::size_t column,
+                         const std::string &msg, const std::string &rule) {
+      throw Error(line, column, rule,
+                  "Input grammar invalid at line " + std::to_string(line) +
+                      ":" + std::to_string(column) + ": " + msg + " (" + rule +
+                      ")");
+    });
+
+    // Note: we have to use ResultExpression in order to avoid being notified
+    // about this rule not being referenced
+    parser.load_grammar(peg_serialization_grammar, "ResultExpression");
+    last_start = "ResultExpression";
+
+    // This is the logger that we want to use from now on (for SeQuant syntax)
+    parser.set_logger(logger);
+
+    parser.enable_ast();
+
+    // Only flag the parser as initialized once the setup ran to completion. If
+    // anything above throws, the next call repeats the setup instead of
+    // handing out a half-configured parser.
+    initialized = true;
+  }
+
+  if (last_start != start_rule) {
+    // Ability to change the start rule on-the-fly is implemented in
+    // https://github.com/yhirose/cpp-peglib/pull/332
+    // Until that is merged, we have to reload the grammar every time
+    // we want to use a different start rule.
+
+    // We assume that the grammar is valid (it has been loaded before) and in
+    // order to ignore 'rule xy not referenced' logs (which our logger turns
+    // into exceptions) we uninstall the logger before reloading the grammar
+    parser.set_logger(peg::Log{});
+    const bool loaded =
+        parser.load_grammar(peg_serialization_grammar, start_rule);
+    parser.set_logger(logger);
+
+    if (!loaded) {
+      // No logger was installed during the reload, so there is no position
+      // information to report (hence line and column are set to zero)
+      throw Error(0, 0, std::string(start_rule),
+                  "Unable to load grammar with start rule '" +
+                      std::string(start_rule) + "'");
+    }
+
+    parser.enable_ast();
+
+    // Remember the active start rule. Without this, the grammar would be
+    // reloaded on every call with a non-default start rule and a later
+    // request for the initial start rule would skip the reload and receive
+    // a parser that still starts at the previously requested rule.
+    last_start = start_rule;
+  }
+
+  return parser;
+}
+
+#define SEQUANT_DESERIALIZATION_SPECIALIZATION(Type, Rule)      \
+  template <>                                                   \
+  Type from_string(std::string_view input,                      \
+                   const DeserializationOptions &options) {     \
+    auto &parser = get_parser(Rule);                            \
+                                                                \
+    std::shared_ptr<peg::Ast> ast;                              \
+    parser.parse(input, ast);                                   \
+                                                                \
+    SEQUANT_ASSERT(ast);                                        \
+                                                                \
+    return ast_to<Type>(*ast, options);                         \
+  }                                                             \
+  template <>                                                   \
+  Type from_string(std::wstring_view input,                     \
+                   const DeserializationOptions &options) {     \
+    return v2::from_string(toUtf8(input), options);             \
+  }
+
+SEQUANT_DESERIALIZATION_SPECIALIZATION(Constant, "Constant")
+SEQUANT_DESERIALIZATION_SPECIALIZATION(ExprPtr, "Expression")
+
+}  // namespace sequant::io::serialization::v2
diff --git a/SeQuant/core/io/serialization/v2/error.cpp b/SeQuant/core/io/serialization/v2/error.cpp
new file mode 100644
index 0000000000..ce163ba08c
--- /dev/null
+++ b/SeQuant/core/io/serialization/v2/error.cpp @@ -0,0 +1,12 @@ +#include + +namespace sequant::io::serialization::v2 { + +Error::Error(std::size_t line, std::size_t column, std::string rule, + std::string msg) + : SerializationError(std::move(msg)), + line(line), + column(column), + rule(std::move(rule)) {} + +} // namespace sequant::io::serialization::v2 diff --git a/SeQuant/core/io/serialization/v2/error.hpp b/SeQuant/core/io/serialization/v2/error.hpp new file mode 100644 index 0000000000..68a9c595ab --- /dev/null +++ b/SeQuant/core/io/serialization/v2/error.hpp @@ -0,0 +1,22 @@ +#ifndef SEQUANT_CORE_IO_SERIALIZATION_V2_ERROR_HPP +#define SEQUANT_CORE_IO_SERIALIZATION_V2_ERROR_HPP + +#include + +#include +#include + +namespace sequant::io::serialization::v2 { + +struct Error : SerializationError { + std::size_t line; + std::size_t column; + std::string rule; + + Error(std::size_t line, std::size_t column, std::string rule, + std::string msg); +}; + +} // namespace sequant::io::serialization::v2 + +#endif // SEQUANT_CORE_IO_SERIALIZATION_V2_ERROR_HPP diff --git a/SeQuant/core/io/serialization/v2/peg_grammar.ipp b/SeQuant/core/io/serialization/v2/peg_grammar.ipp new file mode 100644 index 0000000000..15d651f624 --- /dev/null +++ b/SeQuant/core/io/serialization/v2/peg_grammar.ipp @@ -0,0 +1,60 @@ +using namespace std::literals; +constexpr static std::string_view peg_serialization_grammar = R"( + +Integer <- < [0-9]+ > +Float <- < Integer? '.' Integer > +Real <- Float / Integer +Imaginary <- Real 'i' +Complex <- Imaginary / Real +Constant <- Complex + +# This is a very liberal rule that simply tries to exclude the most important +# non-ID characters (in ASCII range) and then allows everything else +Identifier <- ( [^\u0000-\u002f\u003a-\u0040\u005b-\u0060\u007b-\u00a0] / '_' )+ + +Variable <- < Identifier ('^*')? 
> + +Index <- < Identifier > + +Indices <- ','* Index (','+ Index)* + +IndexGroups <- ';'* Indices (';'+ Indices)* ';'* + +PredefinedSymm <- 'A' / 'S' / 'N' / 'bkS' / 'bkC' / 'bkN' / 'pS' / 'pN' + +Symmetry <- PredefinedSymm + +Tensor <- Identifier ('^*')? '[' IndexGroups ']' ( ':' Symmetry (',' Symmetry)* )? + +Statistic <- 'F' / 'B' + +NormalOrderedOperator <- Identifier '{' IndexGroups '}' (':' Statistic)? + +Function <- '~' Identifier '(' [^)]* ')' + + +Nullary <- '(' Expression ')' / Function / Constant / Tensor / NormalOrderedOperator / Variable + +UnaryOperator <- '+' / '-' + +Unary <- UnaryOperator Nullary + +BinaryOperator <- '+' / '-' / '*' / '/' + +ExprAtom <- Unary / Nullary + +InfixExpression <- ExprAtom ( BinaryOperator ExprAtom )* { + precedence + L - + + L * / +} + +Expression <- ExprAtom Expression / InfixExpression Expression? + +Result <- Tensor / Variable + +ResultExpression <- Result '=' Expression + +%whitespace <- [ \t\r\n]* + +)"sv; diff --git a/SeQuant/core/io/serialization/v2/serialize.cpp b/SeQuant/core/io/serialization/v2/serialize.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/SeQuant/core/utility/conversion.hpp b/SeQuant/core/utility/conversion.hpp index 7af26c70b5..ad4c8f9806 100644 --- a/SeQuant/core/utility/conversion.hpp +++ b/SeQuant/core/utility/conversion.hpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -55,7 +56,7 @@ T string_to_impl(std::string_view str, Arg &&arg) { } // namespace template -concept string_to_supports = +concept from_chars_supports = requires(const char *c, T &v) { std::from_chars(c, c + 1, v); }; /// Converts the provided string to the desired integral type. @@ -74,7 +75,7 @@ concept string_to_supports = /// the parsed value can't be represented as a T. 
template T string_to(std::string_view str, int base = 10) { - static_assert(string_to_supports, + static_assert(from_chars_supports, "Your C++ standard library is missing a std::from_chars " "implementation for this integral type"); return string_to_impl(str, base); @@ -97,11 +98,28 @@ T string_to(std::string_view str, int base = 10) { template T string_to(std::string_view str, std::chars_format fmt = std::chars_format::general) { - static_assert(string_to_supports, +#ifndef __APPLE__ + static_assert(from_chars_supports, "Your C++ standard library is missing a std::from_chars " "implementation for this floating point type"); return string_to_impl(str, fmt); +#else + // For some reason it seems that Apple is unable to supply an implementation + // of std::from_chars so we need to work around its (potential) absence + if constexpr (from_chars_supports) { + // In case they update their standard lib… + return string_to_impl(str, fmt); + } + + // Workaround implementation that doesn't do any error checking - not great + // but better than not being able to use this function at all + std::stringstream stream{std::string(str)}; + T val = 0; + stream >> val; + + return val; +#endif } } // namespace sequant diff --git a/cmake/modules/FindOrFetchCppPeglib.cmake b/cmake/modules/FindOrFetchCppPeglib.cmake new file mode 100644 index 0000000000..865665f708 --- /dev/null +++ b/cmake/modules/FindOrFetchCppPeglib.cmake @@ -0,0 +1,17 @@ +if (NOT TARGET peglib) + include(FetchContent) + + FetchContent_Declare( + peglib + GIT_REPOSITORY "https://github.com/yhirose/cpp-peglib.git" + GIT_TAG "${SEQUANT_TRACKED_CPPPEGLIB_TAG}" + GIT_SHALLOW FALSE # option takes a boolean; GIT_TAG is pinned to a commit hash, which shallow clones do not support + ) + + FetchContent_MakeAvailable(peglib) +endif() + +# postcond check +if (NOT TARGET peglib) + message(FATAL_ERROR "FindOrFetchCppPeglib could not make TARGET peglib available") +endif() diff --git a/external/versions.cmake b/external/versions.cmake index c1989b9922..90c4b5fa87 100644 --- a/external/versions.cmake +++ b/external/versions.cmake @@ 
-44,3 +44,5 @@ set(SEQUANT_OLDEST_EIGEN_VERSION 3.0...5) set(SEQUANT_TRACKED_BTAS_TAG 9c8c8f68fee2b82e64755270a8348e4612cf9941) set(SEQUANT_TRACKED_TAPP_TAG 5d6fb56f7d4cbb4daefe23f408bbdde8cf9fc015) + +set(SEQUANT_TRACKED_CPPPEGLIB_TAG de57145d884e1eb6ca81b3972b588eb1cf5e643e) diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 8a39911424..12d8516724 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -18,8 +18,8 @@ set(symb_test_sources "test_math.cpp" "test_meta.cpp" "test_op.cpp" - "test_parse.cpp" "test_runtime.cpp" + "test_serialization.cpp" "test_space.cpp" "test_tensor.cpp" "test_tensor_network.cpp" diff --git a/tests/unit/test_parse.cpp b/tests/unit/test_parse.cpp deleted file mode 100644 index eeef461ebe..0000000000 --- a/tests/unit/test_parse.cpp +++ /dev/null @@ -1,530 +0,0 @@ -#include -#include - -#include "catch2_sequant.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace Catch { -template <> -struct StringMaker { - static std::string convert( - const sequant::io::serialization::SerializationError& error) { - return "io::serialization::SerializationError{offset: " + - std::to_string(error.offset) + - ", length: " + std::to_string(error.length) + ", what(): '" + - error.what() + "'}"; - } -}; -} // namespace Catch - -struct SerializationErrorMatcher - : Catch::Matchers::MatcherBase< - sequant::io::serialization::SerializationError> { - std::size_t offset; - std::size_t length; - std::string messageFragment; - - SerializationErrorMatcher(std::size_t offset, std::size_t length, - std::string messageFragment = "") - : offset(offset), - length(length), - messageFragment(std::move(messageFragment)) {} - - bool match(const sequant::io::serialization::SerializationError& exception) - const override { - if (exception.offset != offset) { - return false; - } 
- if (exception.length != length) { - return false; - } - if (!messageFragment.empty()) { - std::string message(exception.what()); - auto iter = message.find(messageFragment); - - return iter != std::string::npos; - } - return true; - } - - std::string describe() const override { - return "-- expected {offset: " + std::to_string(offset) + - ", length: " + std::to_string(length) + - (messageFragment.empty() - ? std::string("}") - : ", what() references '" + messageFragment + "'}"); - } -}; - -SerializationErrorMatcher serializationErrorMatches( - std::size_t offset, std::size_t length, std::string messageFragment = "") { - return SerializationErrorMatcher{offset, length, std::move(messageFragment)}; -} - -TEST_CASE("serialization", "[serialization]") { - SECTION("deserialize") { - using namespace sequant; - - auto ctx = get_default_context(); - ctx.set(mbpt::make_sr_spaces()); - auto ctx_resetter = set_scoped_default_context(ctx); - - SECTION("Scalar tensor") { - auto expr = deserialize(L"t{}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().bra().empty()); - REQUIRE(expr->as().ket().empty()); - REQUIRE(expr->as().aux().empty()); - - REQUIRE(expr == deserialize(L"t{;}")); - REQUIRE(expr == deserialize(L"t{;;}")); - REQUIRE(expr == deserialize(L"t^{}_{}")); - REQUIRE(expr == deserialize(L"t_{}^{}")); - } - SECTION("Tensor") { - auto expr = deserialize(L"t{i1;a1}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().label() == L"t"); - REQUIRE(expr->as().bra().size() == 1); - REQUIRE(expr->as().bra().at(0).label() == L"i_1"); - REQUIRE(expr->as().ket().size() == 1); - REQUIRE(expr->as().ket().at(0) == L"a_1"); - REQUIRE(expr->as().aux().empty()); - - REQUIRE(expr == deserialize(L"t_{i1}^{a1}")); - REQUIRE(expr == deserialize(L"t^{a1}_{i1}")); - REQUIRE(expr == deserialize(L"t{i_1; a_1}")); - REQUIRE(expr == deserialize(L"t{i_1; a_1;}")); - REQUIRE(expr == deserialize(L"t_{i_1}^{a_1}")); - - expr = deserialize(L"t{i1,i2;a1,a2}"); - REQUIRE(expr->as().bra().size() == 2); - 
REQUIRE(expr->as().bra().at(0).label() == L"i_1"); - REQUIRE(expr->as().bra().at(1).label() == L"i_2"); - REQUIRE(expr->as().ket().size() == 2); - REQUIRE(expr->as().ket().at(0).label() == L"a_1"); - REQUIRE(expr->as().ket().at(1).label() == L"a_2"); - REQUIRE(expr->as().aux().empty()); - - REQUIRE(expr == deserialize(L"+t{i1, i2; a1, a2}")); - REQUIRE(deserialize(L"-t{i1;a1}")->is()); - REQUIRE(expr == deserialize(L"t{\ti1, \ti2; \na1,\t a2 \t}")); - - // Tensor labels including underscores - REQUIRE(deserialize(L"T_1{i_1;a_1}")->as().label() == - L"T_1"); - - // "Non-standard" tensor labels - REQUIRE(deserialize(L"α{a1;i1}")->as().label() == L"α"); - REQUIRE(deserialize(L"γ_1{a1;i1}")->as().label() == - L"γ_1"); - REQUIRE(deserialize(L"t⁔1{a1;i1}")->as().label() == - L"t⁔1"); - REQUIRE(deserialize(L"t¹{a1;i1}")->as().label() == - L"t¹"); - REQUIRE(deserialize(L"t⁸{a1;i1}")->as().label() == - L"t⁸"); - REQUIRE(deserialize(L"t⁻{a1;i1}")->as().label() == - L"t⁻"); - REQUIRE(deserialize(L"tₐ{a1;i1}")->as().label() == - L"tₐ"); - REQUIRE(deserialize(L"t₋{a1;i1}")->as().label() == - L"t₋"); - REQUIRE(deserialize(L"t₌{a1;i1}")->as().label() == - L"t₌"); - REQUIRE(deserialize(L"t↓{a1;i1}")->as().label() == - L"t↓"); - REQUIRE(deserialize(L"t↑{a1;i1}")->as().label() == - L"t↑"); - - // "Non-standard" index names - auto expr1 = deserialize(L"t{a↓1;i↑1}"); - REQUIRE(expr1->as().bra().at(0).label() == L"a↓_1"); - REQUIRE(expr1->as().ket().at(0).label() == L"i↑_1"); - - // Auxiliary indices - expr = deserialize(L"t{;;i1}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().bra().empty()); - REQUIRE(expr->as().ket().empty()); - REQUIRE(expr->as().aux().size() == 1); - REQUIRE(expr->as().aux()[0].label() == L"i_1"); - - // All index groups at once - expr = deserialize(L"t{i1,i2;a1;x1,x2}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().bra().size() == 2); - REQUIRE(expr->as().bra().at(0).label() == L"i_1"); - REQUIRE(expr->as().bra().at(1).label() == L"i_2"); - 
REQUIRE(expr->as().ket().size() == 1); - REQUIRE(expr->as().ket().at(0).label() == L"a_1"); - REQUIRE(expr->as().aux().size() == 2); - REQUIRE(expr->as().aux().at(0).label() == L"x_1"); - REQUIRE(expr->as().aux().at(1).label() == L"x_2"); - } - - SECTION("Tensor with symmetry annotation") { - auto expr1 = deserialize(L"t{a1;i1}:A"); - auto expr2 = deserialize(L"t{a1;i1}:S-C"); - auto expr3 = deserialize(L"t{a1;i1}:N-S-N"); - - const Tensor& t1 = expr1->as(); - const Tensor& t2 = expr2->as(); - const Tensor& t3 = expr3->as(); - - REQUIRE(t1.symmetry() == Symmetry::Antisymm); - - REQUIRE(t2.symmetry() == Symmetry::Symm); - REQUIRE(t2.braket_symmetry() == BraKetSymmetry::Conjugate); - - REQUIRE(t3.symmetry() == Symmetry::Nonsymm); - REQUIRE(t3.braket_symmetry() == BraKetSymmetry::Symm); - REQUIRE(t3.column_symmetry() == ColumnSymmetry::Nonsymm); - } - - SECTION("NormalOperator") { - { - using NOp = FNOperator; - auto expr = deserialize(L"a{i1;a1}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().label() == NOp::labels()[0]); - REQUIRE(expr->as().creators().size() == 1); - REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); - REQUIRE(expr->as().annihilators().size() == 1); - REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); - REQUIRE(expr->as().vacuum() == Vacuum::Physical); - } - { - using NOp = FNOperator; - auto expr = deserialize(L"ã{i1;}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().label() == NOp::labels()[1]); - REQUIRE(expr->as().creators().size() == 0); - REQUIRE(expr->as().annihilators().size() == 1); - REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); - REQUIRE(expr->as().vacuum() == Vacuum::SingleProduct); - } - - { - using NOp = BNOperator; - auto expr = deserialize(L"b{i1;a1}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().label() == NOp::labels()[0]); - REQUIRE(expr->as().creators().size() == 1); - REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); - REQUIRE(expr->as().annihilators().size() == 1); - 
REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); - REQUIRE(expr->as().vacuum() == Vacuum::Physical); - } - { - using NOp = BNOperator; - auto expr = deserialize(L"b̃{;a1}"); - REQUIRE(expr->is()); - REQUIRE(expr->as().label() == NOp::labels()[1]); - REQUIRE(expr->as().creators().size() == 1); - REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); - REQUIRE(expr->as().annihilators().size() == 0); - REQUIRE(expr->as().vacuum() == Vacuum::SingleProduct); - } - } - - SECTION("Constant") { - REQUIRE(deserialize(L"1/2")->is()); - REQUIRE(deserialize(L"0/2")->is()); - REQUIRE(deserialize(L"-1/2")->is()); - REQUIRE(deserialize(L"-0/2")->is()); - REQUIRE(deserialize(L"1")->is()); - REQUIRE(deserialize(L"123")->is()); - REQUIRE(deserialize(L"1.")->is()); - REQUIRE(deserialize(L"01.00")->is()); - REQUIRE(deserialize(L"0 / 10")->is()); - REQUIRE(deserialize(L"0.5/0.25")->is()); - REQUIRE(deserialize(L".4")->is()); - } - - SECTION("Variable") { - // SeQuant variable is just a label followed by an optional ^* - // to denote if the variable is conjugated - REQUIRE(deserialize(L"a")->is()); - REQUIRE(deserialize(L"α")->is()); - REQUIRE(deserialize(L"β")->is()); - REQUIRE(deserialize(L"γ")->is()); - REQUIRE(deserialize(L"λ")->is()); - REQUIRE(deserialize(L"δ")->is()); - REQUIRE(deserialize(L"a^*")->is()); - REQUIRE(deserialize(L"α^*")->is()); - REQUIRE(deserialize(L"β^*")->is()); - REQUIRE(deserialize(L"b^*")->is()); - REQUIRE(deserialize(L"b^*")->as().conjugated()); - REQUIRE(deserialize(L"b^*")->as().label() == L"b"); - } - - SECTION("Product") { - auto expr = deserialize(L"-1/2 g{i2,i3; i1,a2} t{a1,a2; i2,i3}"); - REQUIRE(expr->is()); - - auto const& prod = expr->as(); - REQUIRE(prod.scalar() == rational{-1, 2}); - REQUIRE(prod.factor(0) == - deserialize(L"g_{i_2, i_3}^{i_1, a_2}")); - REQUIRE(prod.factor(1) == deserialize(L"t^{i2, i3}_{a1, a2}")); - REQUIRE(deserialize(L"-1/2 * δ * t{i1;a1}") == - deserialize(L"-1/2 δ t{i1;a1}")); - auto const prod2 = - 
deserialize(L"-1/2 * δ * γ * t{i1;a1}")->as(); - REQUIRE(prod2.scalar() == rational{-1, 2}); - REQUIRE(prod2.factor(0) == ex(L"δ")); - REQUIRE(prod2.factor(1) == ex(L"γ")); - REQUIRE(prod2.factor(2)->is()); - } - - SECTION("Sum") { - auto expr1 = deserialize( - L"f{a1;i1}" - "- 1/2*g{i2,a1; a2,a3}t{a2,a3; i1,i2}"); - REQUIRE(expr1->is()); - - auto const& sum1 = expr1->as(); - REQUIRE(sum1.summand(0) == deserialize(L"f{a1;i1}")); - REQUIRE( - sum1.summand(1) == - deserialize(L"- 1/2 * g{i2,a1; a2,a3} * t{a2,a3; i1,i2}")); - - auto expr2 = deserialize(L"a - 4"); - REQUIRE(expr2->is()); - - auto const& sum2 = expr2->as(); - REQUIRE(sum2.summand(0) == deserialize(L"a")); - REQUIRE(sum2.summand(1) == deserialize(L"-4")); - } - - SECTION("Parentheses") { - auto expr1 = deserialize( - L"-1/2 g{i2,i3; a2,a3} * ( t{a1,a3; i2,i3} * t{a2;i1} )"); - REQUIRE(expr1->is()); - - auto const& prod1 = expr1->as(); - REQUIRE(prod1.size() == 2); - REQUIRE(prod1.scalar() == rational{-1, 2}); - REQUIRE(prod1.factor(0)->is()); - REQUIRE(prod1.factor(1)->is()); - REQUIRE(prod1.factor(1)->size() == 2); - - auto expr2 = deserialize( - L"(-1/2) ( g{i2,i3; a2,a3} * t{a1,a3; i2,i3} ) * (t{a2;i1})"); - REQUIRE(expr2->is()); - - auto const& prod2 = expr2->as(); - REQUIRE(prod2.size() == 2); - REQUIRE(prod2.scalar() == rational{-1, 2}); - REQUIRE(prod2.factor(0)->is()); - REQUIRE(prod2.factor(0)->at(0) == - deserialize(L"g{i2,i3; a2,a3}")); - REQUIRE(prod2.factor(0)->at(1) == - deserialize(L"t{a1,a3; i2,i3}")); - REQUIRE(prod2.factor(1) == deserialize(L"t{a2;i1}")); - - auto expr3 = deserialize( - L"(-1/2) ( g{i2,i3; a2,a3} * t{a1,a3; i2,i3} ) * (1/2) * " - L"((t{a2;i1}))"); - REQUIRE(expr3->is()); - - auto const& prod3 = expr3->as(); - REQUIRE(prod3.size() == 2); - REQUIRE(prod3.scalar() == rational{-1, 4}); - REQUIRE(prod3.factor(0)->is()); - REQUIRE(prod3.factor(0)->at(0) == - deserialize(L"g{i2,i3; a2,a3}")); - REQUIRE(prod3.factor(0)->at(1) == - deserialize(L"t{a1,a3; i2,i3}")); - 
REQUIRE(prod3.factor(1) == deserialize(L"t{a2;i1}")); - - auto expr4 = deserialize(L"1/2 (a + b) * c"); - REQUIRE(expr4->is()); - - const auto& prod4 = expr4->as(); - REQUIRE(prod4.size() == 2); - REQUIRE(prod4.scalar() == rational{1, 2}); - REQUIRE(prod4.factor(1)->is()); - REQUIRE(prod4.factor(1)->as().label() == L"c"); - REQUIRE(prod4.factor(0)->is()); - const auto& nestedSum = prod4.factor(0)->as(); - REQUIRE(nestedSum.size() == 2); - REQUIRE(nestedSum.summand(0)->is()); - REQUIRE(nestedSum.summand(0)->as().label() == L"a"); - REQUIRE(nestedSum.summand(1)->is()); - REQUIRE(nestedSum.summand(1)->as().label() == L"b"); - } - - SECTION("Mixed") { - auto expr = deserialize( - L"0.25 g{a1,a2; i1,i2}" - "+ 1/4 g{i3,i4; a3,a4} (t{a3;i1} * t{a4;i2}) * (t{a1;i3} * " - "t{a2;i4})"); - REQUIRE(expr->is()); - auto const& sum = expr->as(); - - REQUIRE(sum.size() == 2); - - REQUIRE(sum.summand(0)->is()); - REQUIRE(sum.summand(0)->as().scalar() == rational{1, 4}); - REQUIRE(sum.summand(0)->size() == 1); - REQUIRE(sum.summand(0)->at(0) == - deserialize(L"g{a1,a2; i1,i2}")); - - REQUIRE(sum.summand(1)->is()); - auto const& prod = sum.summand(1)->as(); - - REQUIRE(prod.scalar() == rational{1, 4}); - REQUIRE(prod.size() == 3); - REQUIRE(prod.factor(0) == deserialize(L"g{i3,i4; a3,a4}")); - - REQUIRE(prod.factor(1)->is()); - REQUIRE(prod.factor(1)->at(0) == deserialize(L"t{a3;i1}")); - REQUIRE(prod.factor(1)->at(1) == deserialize(L"t{a4;i2}")); - - REQUIRE(prod.factor(2)->is()); - REQUIRE(prod.factor(2)->at(0) == deserialize(L"t{a1;i3}")); - REQUIRE(prod.factor(2)->at(1) == deserialize(L"t{a2;i4}")); - } - - SECTION("Empty input") { REQUIRE(deserialize(L"") == nullptr); } - - SECTION("Error handling") { - SECTION("Exception type") { - std::vector inputs = {L"t^", - L"a + + b" - L"1/t" - L"T{}" - L"T^{i1}{a1}"}; - - for (const std::wstring& current : inputs) { - REQUIRE_THROWS_AS(deserialize(current), - io::serialization::SerializationError); - } - } - - SECTION("Invalid index") { 
- REQUIRE_THROWS_MATCHES(deserialize(L"t{i1;}"), - io::serialization::SerializationError, - serializationErrorMatches(5, 3, "proto")); - REQUIRE_THROWS_MATCHES( - deserialize(L"t{i1;az3}"), - io::serialization::SerializationError, - serializationErrorMatches(5, 3, "Unknown index space")); - } - - SECTION("Invalid symmetry") { - REQUIRE_THROWS_MATCHES( - deserialize(L"t{i1;a3}:P"), - io::serialization::SerializationError, - serializationErrorMatches(9, 1, "Invalid symmetry specifier")); - } - } - } - - SECTION("deserialize") { - using namespace sequant; - - SECTION("constant") { - ResultExpr result = deserialize(L"A = 3"); - - REQUIRE(result.has_label()); - REQUIRE(result.label() == L"A"); - REQUIRE(result.bra().empty()); - REQUIRE(result.ket().empty()); - REQUIRE(result.symmetry() == Symmetry::Nonsymm); - REQUIRE(result.braket_symmetry() == BraKetSymmetry::Nonsymm); - REQUIRE(result.column_symmetry() == ColumnSymmetry::Nonsymm); - - REQUIRE(result.expression().is()); - REQUIRE(result.expression().as().value() == 3); - } - SECTION("contraction") { - ResultExpr result = deserialize( - L"R{i1,i2;e1,e2}:A = f{e2;e3} t{e1,e3;i1,i2}"); - - REQUIRE(result.has_label()); - REQUIRE(result.label() == L"R"); - REQUIRE(result.bra().size() == 2); - REQUIRE(result.bra()[0].full_label() == L"i_1"); - REQUIRE(result.bra()[1].full_label() == L"i_2"); - REQUIRE(result.ket().size() == 2); - REQUIRE(result.ket()[0].full_label() == L"e_1"); - REQUIRE(result.ket()[1].full_label() == L"e_2"); - REQUIRE(result.symmetry() == Symmetry::Antisymm); - REQUIRE(result.braket_symmetry() == - get_default_context().braket_symmetry()); - REQUIRE(result.column_symmetry() == ColumnSymmetry::Symm); - - REQUIRE(result.expression().is()); - const Product& prod = result.expression().as(); - REQUIRE(prod.size() == 2); - REQUIRE(prod.factor(0).is()); - REQUIRE(prod.factor(0).as().label() == L"f"); - REQUIRE(prod.factor(1).is()); - REQUIRE(prod.factor(1).as().label() == L"t"); - } - } - - SECTION("serialize") 
{ - using namespace sequant; - - std::vector expressions = { - L"t{a_1,a_2;a_3,a_4}:N-C-S", - L"42", - L"1/2", - L"-1/4 t{a_1,i_1;a_2,i_2}:S-N-S", - L"a + b - 4 specialVariable", - L"variable + A{a_1;i_1}:N-N-S * B{i_1;a_1}:A-C-S", - L"1/2 (a + b) * c", - L"T1{}:N-N-N + T2{;;x_1}:N-N-N * T3{;;x_1}:N-N-N " - L"+ T4{a_1;;x_2}:S-C-S * T5{;a_1;x_2}:S-S-S", - L"q1 * q2^* * q3", - L"1/2 ã{i_1;i_2} * b̃{i_3;i_4}"}; - - for (const std::wstring& current : expressions) { - ExprPtr expression = deserialize(current); - - REQUIRE(serialize(expression, {.annot_symm = true}) == current); - } - - SECTION("result_expressions") { - std::vector expressions = { - L"A = 5", - L"A = g{i_1,i_2;e_1,e_2}:S-N-S * t{e_1,e_2;i_1,i_2}:N-N-S", - L"R{i_1,i_2;e_1,e_2}:A-N-S = f{e_2;e_3}:A-N-S * " - L"t{e_1,e_3;i_1,i_2}:A-N-S + " - L"g{i_1,i_2;e_1,e_2}:A-N-S", - }; - - for (const std::wstring& current : expressions) { - ResultExpr result = deserialize(current); - - REQUIRE(serialize(result, {.annot_symm = true}) == current); - } - } - } -} diff --git a/tests/unit/test_serialization.cpp b/tests/unit/test_serialization.cpp new file mode 100644 index 0000000000..facf4f0184 --- /dev/null +++ b/tests/unit/test_serialization.cpp @@ -0,0 +1,561 @@ +#include +#include +#include + +#include "catch2_sequant.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace Catch { +template <> +struct StringMaker { + static std::string convert( + const sequant::io::serialization::v1::Error& error) { + return "io::serialization::v1::Error{offset: " + + std::to_string(error.offset) + + ", length: " + std::to_string(error.length) + ", what(): '" + + error.what() + "'}"; + } +}; +} // namespace Catch + +struct SerializationErrorV1Matcher + : Catch::Matchers::MatcherBase { + std::size_t offset; + std::size_t length; + std::string messageFragment; + + 
SerializationErrorV1Matcher(std::size_t offset, std::size_t length, + std::string messageFragment = "") + : offset(offset), + length(length), + messageFragment(std::move(messageFragment)) {} + + bool match( + const sequant::io::serialization::v1::Error& exception) const override { + if (exception.offset != offset) { + return false; + } + if (exception.length != length) { + return false; + } + if (!messageFragment.empty()) { + std::string message(exception.what()); + auto iter = message.find(messageFragment); + + return iter != std::string::npos; + } + return true; + } + + std::string describe() const override { + return "-- expected {offset: " + std::to_string(offset) + + ", length: " + std::to_string(length) + + (messageFragment.empty() + ? std::string("}") + : ", what() references '" + messageFragment + "'}"); + } +}; + +SerializationErrorV1Matcher serializationErrorV1Matches( + std::size_t offset, std::size_t length, std::string messageFragment = "") { + return SerializationErrorV1Matcher{offset, length, + std::move(messageFragment)}; +} + +TEST_CASE("serialization", "[serialization]") { + using namespace sequant; + using namespace io::serialization; + SECTION("v1") { + SECTION("from_string") { + auto ctx = get_default_context(); + ctx.set(mbpt::make_sr_spaces()); + auto ctx_resetter = set_scoped_default_context(ctx); + + SECTION("Scalar tensor") { + auto expr = v1::from_string(L"t{}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().empty()); + REQUIRE(expr->as().ket().empty()); + REQUIRE(expr->as().aux().empty()); + + REQUIRE(expr == v1::from_string(L"t{;}")); + REQUIRE(expr == v1::from_string(L"t{;;}")); + REQUIRE(expr == v1::from_string(L"t^{}_{}")); + REQUIRE(expr == v1::from_string(L"t_{}^{}")); + } + SECTION("Tensor") { + auto expr = v1::from_string(L"t{i1;a1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().label() == L"t"); + REQUIRE(expr->as().bra().size() == 1); + REQUIRE(expr->as().bra().at(0).label() == L"i_1"); + REQUIRE(expr->as().ket().size() == 
1); + REQUIRE(expr->as().ket().at(0) == L"a_1"); + REQUIRE(expr->as().aux().empty()); + + REQUIRE(expr == v1::from_string(L"t_{i1}^{a1}")); + REQUIRE(expr == v1::from_string(L"t^{a1}_{i1}")); + REQUIRE(expr == v1::from_string(L"t{i_1; a_1}")); + REQUIRE(expr == v1::from_string(L"t{i_1; a_1;}")); + REQUIRE(expr == v1::from_string(L"t_{i_1}^{a_1}")); + + expr = v1::from_string(L"t{i1,i2;a1,a2}"); + REQUIRE(expr->as().bra().size() == 2); + REQUIRE(expr->as().bra().at(0).label() == L"i_1"); + REQUIRE(expr->as().bra().at(1).label() == L"i_2"); + REQUIRE(expr->as().ket().size() == 2); + REQUIRE(expr->as().ket().at(0).label() == L"a_1"); + REQUIRE(expr->as().ket().at(1).label() == L"a_2"); + REQUIRE(expr->as().aux().empty()); + + REQUIRE(expr == v1::from_string(L"+t{i1, i2; a1, a2}")); + REQUIRE(v1::from_string(L"-t{i1;a1}")->is()); + REQUIRE(expr == + v1::from_string(L"t{\ti1, \ti2; \na1,\t a2 \t}")); + + // Tensor labels including underscores + REQUIRE( + v1::from_string(L"T_1{i_1;a_1}")->as().label() == + L"T_1"); + + // "Non-standard" tensor labels + REQUIRE(v1::from_string(L"α{a1;i1}")->as().label() == + L"α"); + REQUIRE(v1::from_string(L"γ_1{a1;i1}")->as().label() == + L"γ_1"); + REQUIRE(v1::from_string(L"t⁔1{a1;i1}")->as().label() == + L"t⁔1"); + REQUIRE(v1::from_string(L"t¹{a1;i1}")->as().label() == + L"t¹"); + REQUIRE(v1::from_string(L"t⁸{a1;i1}")->as().label() == + L"t⁸"); + REQUIRE(v1::from_string(L"t⁻{a1;i1}")->as().label() == + L"t⁻"); + REQUIRE(v1::from_string(L"tₐ{a1;i1}")->as().label() == + L"tₐ"); + REQUIRE(v1::from_string(L"t₋{a1;i1}")->as().label() == + L"t₋"); + REQUIRE(v1::from_string(L"t₌{a1;i1}")->as().label() == + L"t₌"); + REQUIRE(v1::from_string(L"t↓{a1;i1}")->as().label() == + L"t↓"); + REQUIRE(v1::from_string(L"t↑{a1;i1}")->as().label() == + L"t↑"); + + // "Non-standard" index names + auto expr1 = v1::from_string(L"t{a↓1;i↑1}"); + REQUIRE(expr1->as().bra().at(0).label() == L"a↓_1"); + REQUIRE(expr1->as().ket().at(0).label() == L"i↑_1"); + + // 
Auxiliary indices + expr = v1::from_string(L"t{;;i1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().empty()); + REQUIRE(expr->as().ket().empty()); + REQUIRE(expr->as().aux().size() == 1); + REQUIRE(expr->as().aux()[0].label() == L"i_1"); + + // All index groups at once + expr = v1::from_string(L"t{i1,i2;a1;x1,x2}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().size() == 2); + REQUIRE(expr->as().bra().at(0).label() == L"i_1"); + REQUIRE(expr->as().bra().at(1).label() == L"i_2"); + REQUIRE(expr->as().ket().size() == 1); + REQUIRE(expr->as().ket().at(0).label() == L"a_1"); + REQUIRE(expr->as().aux().size() == 2); + REQUIRE(expr->as().aux().at(0).label() == L"x_1"); + REQUIRE(expr->as().aux().at(1).label() == L"x_2"); + } + + SECTION("Tensor with symmetry annotation") { + auto expr1 = v1::from_string(L"t{a1;i1}:A"); + auto expr2 = v1::from_string(L"t{a1;i1}:S-C"); + auto expr3 = v1::from_string(L"t{a1;i1}:N-S-N"); + + const Tensor& t1 = expr1->as(); + const Tensor& t2 = expr2->as(); + const Tensor& t3 = expr3->as(); + + REQUIRE(t1.symmetry() == Symmetry::Antisymm); + + REQUIRE(t2.symmetry() == Symmetry::Symm); + REQUIRE(t2.braket_symmetry() == BraKetSymmetry::Conjugate); + + REQUIRE(t3.symmetry() == Symmetry::Nonsymm); + REQUIRE(t3.braket_symmetry() == BraKetSymmetry::Symm); + REQUIRE(t3.column_symmetry() == ColumnSymmetry::Nonsymm); + } + + SECTION("NormalOperator") { + { + using NOp = FNOperator; + auto expr = v1::from_string(L"a{i1;a1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().label() == NOp::labels()[0]); + REQUIRE(expr->as().creators().size() == 1); + REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); + REQUIRE(expr->as().annihilators().size() == 1); + REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); + REQUIRE(expr->as().vacuum() == Vacuum::Physical); + } + { + using NOp = FNOperator; + auto expr = v1::from_string(L"ã{i1;}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().label() == NOp::labels()[1]); + 
REQUIRE(expr->as().creators().size() == 0); + REQUIRE(expr->as().annihilators().size() == 1); + REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); + REQUIRE(expr->as().vacuum() == Vacuum::SingleProduct); + } + + { + using NOp = BNOperator; + auto expr = v1::from_string(L"b{i1;a1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().label() == NOp::labels()[0]); + REQUIRE(expr->as().creators().size() == 1); + REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); + REQUIRE(expr->as().annihilators().size() == 1); + REQUIRE(expr->as().annihilators().at(0).index() == L"i_1"); + REQUIRE(expr->as().vacuum() == Vacuum::Physical); + } + { + using NOp = BNOperator; + auto expr = v1::from_string(L"b̃{;a1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().label() == NOp::labels()[1]); + REQUIRE(expr->as().creators().size() == 1); + REQUIRE(expr->as().creators().at(0).index().label() == L"a_1"); + REQUIRE(expr->as().annihilators().size() == 0); + REQUIRE(expr->as().vacuum() == Vacuum::SingleProduct); + } + } + + SECTION("Constant") { + REQUIRE(v1::from_string(L"1/2")->is()); + REQUIRE(v1::from_string(L"0/2")->is()); + REQUIRE(v1::from_string(L"-1/2")->is()); + REQUIRE(v1::from_string(L"-0/2")->is()); + REQUIRE(v1::from_string(L"1")->is()); + REQUIRE(v1::from_string(L"123")->is()); + REQUIRE(v1::from_string(L"1.")->is()); + REQUIRE(v1::from_string(L"01.00")->is()); + REQUIRE(v1::from_string(L"0 / 10")->is()); + REQUIRE(v1::from_string(L"0.5/0.25")->is()); + REQUIRE(v1::from_string(L".4")->is()); + } + + SECTION("Variable") { + // SeQuant variable is just a label followed by an optional ^* + // to denote if the variable is conjugated + REQUIRE(v1::from_string(L"a")->is()); + REQUIRE(v1::from_string(L"α")->is()); + REQUIRE(v1::from_string(L"β")->is()); + REQUIRE(v1::from_string(L"γ")->is()); + REQUIRE(v1::from_string(L"λ")->is()); + REQUIRE(v1::from_string(L"δ")->is()); + REQUIRE(v1::from_string(L"a^*")->is()); + REQUIRE(v1::from_string(L"α^*")->is()); + 
REQUIRE(v1::from_string(L"β^*")->is()); + REQUIRE(v1::from_string(L"b^*")->is()); + REQUIRE(v1::from_string(L"b^*")->as().conjugated()); + REQUIRE(v1::from_string(L"b^*")->as().label() == + L"b"); + } + + SECTION("Product") { + auto expr = + v1::from_string(L"-1/2 g{i2,i3; i1,a2} t{a1,a2; i2,i3}"); + REQUIRE(expr->is()); + + auto const& prod = expr->as(); + REQUIRE(prod.scalar() == rational{-1, 2}); + REQUIRE(prod.factor(0) == + v1::from_string(L"g_{i_2, i_3}^{i_1, a_2}")); + REQUIRE(prod.factor(1) == + v1::from_string(L"t^{i2, i3}_{a1, a2}")); + REQUIRE(v1::from_string(L"-1/2 * δ * t{i1;a1}") == + v1::from_string(L"-1/2 δ t{i1;a1}")); + auto const prod2 = + v1::from_string(L"-1/2 * δ * γ * t{i1;a1}")->as(); + REQUIRE(prod2.scalar() == rational{-1, 2}); + REQUIRE(prod2.factor(0) == ex(L"δ")); + REQUIRE(prod2.factor(1) == ex(L"γ")); + REQUIRE(prod2.factor(2)->is()); + } + + SECTION("Sum") { + auto expr1 = v1::from_string( + L"f{a1;i1}" + "- 1/2*g{i2,a1; a2,a3}t{a2,a3; i1,i2}"); + REQUIRE(expr1->is()); + + auto const& sum1 = expr1->as(); + REQUIRE(sum1.summand(0) == v1::from_string(L"f{a1;i1}")); + REQUIRE(sum1.summand(1) == + v1::from_string( + L"- 1/2 * g{i2,a1; a2,a3} * t{a2,a3; i1,i2}")); + + auto expr2 = v1::from_string(L"a - 4"); + REQUIRE(expr2->is()); + + auto const& sum2 = expr2->as(); + REQUIRE(sum2.summand(0) == v1::from_string(L"a")); + REQUIRE(sum2.summand(1) == v1::from_string(L"-4")); + } + + SECTION("Parentheses") { + auto expr1 = v1::from_string( + L"-1/2 g{i2,i3; a2,a3} * ( t{a1,a3; i2,i3} * t{a2;i1} )"); + REQUIRE(expr1->is()); + + auto const& prod1 = expr1->as(); + REQUIRE(prod1.size() == 2); + REQUIRE(prod1.scalar() == rational{-1, 2}); + REQUIRE(prod1.factor(0)->is()); + REQUIRE(prod1.factor(1)->is()); + REQUIRE(prod1.factor(1)->size() == 2); + + auto expr2 = v1::from_string( + L"(-1/2) ( g{i2,i3; a2,a3} * t{a1,a3; i2,i3} ) * (t{a2;i1})"); + REQUIRE(expr2->is()); + + auto const& prod2 = expr2->as(); + REQUIRE(prod2.size() == 2); + 
REQUIRE(prod2.scalar() == rational{-1, 2}); + REQUIRE(prod2.factor(0)->is()); + REQUIRE(prod2.factor(0)->at(0) == + v1::from_string(L"g{i2,i3; a2,a3}")); + REQUIRE(prod2.factor(0)->at(1) == + v1::from_string(L"t{a1,a3; i2,i3}")); + REQUIRE(prod2.factor(1) == v1::from_string(L"t{a2;i1}")); + + auto expr3 = v1::from_string( + L"(-1/2) ( g{i2,i3; a2,a3} * t{a1,a3; i2,i3} ) * (1/2) * " + L"((t{a2;i1}))"); + REQUIRE(expr3->is()); + + auto const& prod3 = expr3->as(); + REQUIRE(prod3.size() == 2); + REQUIRE(prod3.scalar() == rational{-1, 4}); + REQUIRE(prod3.factor(0)->is()); + REQUIRE(prod3.factor(0)->at(0) == + v1::from_string(L"g{i2,i3; a2,a3}")); + REQUIRE(prod3.factor(0)->at(1) == + v1::from_string(L"t{a1,a3; i2,i3}")); + REQUIRE(prod3.factor(1) == v1::from_string(L"t{a2;i1}")); + + auto expr4 = v1::from_string(L"1/2 (a + b) * c"); + REQUIRE(expr4->is()); + + const auto& prod4 = expr4->as(); + REQUIRE(prod4.size() == 2); + REQUIRE(prod4.scalar() == rational{1, 2}); + REQUIRE(prod4.factor(1)->is()); + REQUIRE(prod4.factor(1)->as().label() == L"c"); + REQUIRE(prod4.factor(0)->is()); + const auto& nestedSum = prod4.factor(0)->as(); + REQUIRE(nestedSum.size() == 2); + REQUIRE(nestedSum.summand(0)->is()); + REQUIRE(nestedSum.summand(0)->as().label() == L"a"); + REQUIRE(nestedSum.summand(1)->is()); + REQUIRE(nestedSum.summand(1)->as().label() == L"b"); + } + + SECTION("Mixed") { + auto expr = v1::from_string( + L"0.25 g{a1,a2; i1,i2}" + "+ 1/4 g{i3,i4; a3,a4} (t{a3;i1} * t{a4;i2}) * (t{a1;i3} * " + "t{a2;i4})"); + REQUIRE(expr->is()); + auto const& sum = expr->as(); + + REQUIRE(sum.size() == 2); + + REQUIRE(sum.summand(0)->is()); + REQUIRE(sum.summand(0)->as().scalar() == rational{1, 4}); + REQUIRE(sum.summand(0)->size() == 1); + REQUIRE(sum.summand(0)->at(0) == + v1::from_string(L"g{a1,a2; i1,i2}")); + + REQUIRE(sum.summand(1)->is()); + auto const& prod = sum.summand(1)->as(); + + REQUIRE(prod.scalar() == rational{1, 4}); + REQUIRE(prod.size() == 3); + 
REQUIRE(prod.factor(0) == v1::from_string(L"g{i3,i4; a3,a4}")); + + REQUIRE(prod.factor(1)->is()); + REQUIRE(prod.factor(1)->at(0) == v1::from_string(L"t{a3;i1}")); + REQUIRE(prod.factor(1)->at(1) == v1::from_string(L"t{a4;i2}")); + + REQUIRE(prod.factor(2)->is()); + REQUIRE(prod.factor(2)->at(0) == v1::from_string(L"t{a1;i3}")); + REQUIRE(prod.factor(2)->at(1) == v1::from_string(L"t{a2;i4}")); + } + + SECTION("Empty input") { + REQUIRE(v1::from_string(L"") == nullptr); + } + + SECTION("Error handling") { + SECTION("Exception type") { + std::vector inputs = {L"t^", + L"a + + b" + L"1/t" + L"T{}" + L"T^{i1}{a1}"}; /* NOTE(review): no commas follow the second literal, so adjacent string literals are concatenated and this list holds only two inputs (L"t^" and one merged string). Confirm whether five separate inputs were intended; if so, L"T{}" alone would presumably not throw, since L"t{}" is parsed successfully in the "Scalar tensor" section. */ + + for (const std::wstring& current : inputs) { + REQUIRE_THROWS_AS(v1::from_string(current), + io::serialization::v1::Error); + } + } + + SECTION("Invalid index") { + REQUIRE_THROWS_MATCHES(v1::from_string(L"t{i1;}"), + io::serialization::v1::Error, + serializationErrorV1Matches(5, 3, "proto")); + REQUIRE_THROWS_MATCHES( + v1::from_string(L"t{i1;az3}"), + io::serialization::v1::Error, + serializationErrorV1Matches(5, 3, "Unknown index space")); + } + + SECTION("Invalid symmetry") { + REQUIRE_THROWS_MATCHES( + v1::from_string(L"t{i1;a3}:P"), + io::serialization::v1::Error, + serializationErrorV1Matches(9, 1, "Invalid symmetry specifier")); + } + } + } + + SECTION("from_string") { + SECTION("constant") { + ResultExpr result = v1::from_string(L"A = 3"); + + REQUIRE(result.has_label()); + REQUIRE(result.label() == L"A"); + REQUIRE(result.bra().empty()); + REQUIRE(result.ket().empty()); + REQUIRE(result.symmetry() == Symmetry::Nonsymm); + REQUIRE(result.braket_symmetry() == BraKetSymmetry::Nonsymm); + REQUIRE(result.column_symmetry() == ColumnSymmetry::Nonsymm); + + REQUIRE(result.expression().is()); + REQUIRE(result.expression().as().value() == 3); + } + SECTION("contraction") { + ResultExpr result = v1::from_string( + L"R{i1,i2;e1,e2}:A = f{e2;e3} t{e1,e3;i1,i2}"); + + REQUIRE(result.has_label()); + REQUIRE(result.label() == L"R"); + 
REQUIRE(result.bra().size() == 2); + REQUIRE(result.bra()[0].full_label() == L"i_1"); /* full_label reports i1 from the input as i_1 */ + REQUIRE(result.bra()[1].full_label() == L"i_2"); + REQUIRE(result.ket().size() == 2); + REQUIRE(result.ket()[0].full_label() == L"e_1"); + REQUIRE(result.ket()[1].full_label() == L"e_2"); + REQUIRE(result.symmetry() == Symmetry::Antisymm); /* the parsed input annotates the result with :A */ + REQUIRE(result.braket_symmetry() == + get_default_context().braket_symmetry()); /* compared against the default context rather than a fixed enumerator */ + REQUIRE(result.column_symmetry() == ColumnSymmetry::Symm); + + REQUIRE(result.expression().is()); + const Product& prod = result.expression().as(); + REQUIRE(prod.size() == 2); + REQUIRE(prod.factor(0).is()); + REQUIRE(prod.factor(0).as().label() == L"f"); + REQUIRE(prod.factor(1).is()); + REQUIRE(prod.factor(1).as().label() == L"t"); + } + } + + SECTION("to_string") { /* round trip: each string is parsed and then serialized with annot_symm = true, which must reproduce the input exactly */ + std::vector expressions = { + L"t{a_1,a_2;a_3,a_4}:N-C-S", + L"42", + L"1/2", + L"-1/4 t{a_1,i_1;a_2,i_2}:S-N-S", + L"a + b - 4 specialVariable", + L"variable + A{a_1;i_1}:N-N-S * B{i_1;a_1}:A-C-S", + L"1/2 (a + b) * c", + L"T1{}:N-N-N + T2{;;x_1}:N-N-N * T3{;;x_1}:N-N-N " + L"+ T4{a_1;;x_2}:S-C-S * T5{;a_1;x_2}:S-S-S", /* the two adjacent literals form one expression split across source lines */ + L"q1 * q2^* * q3", + L"1/2 ã{i_1;i_2} * b̃{i_3;i_4}"}; /* labels containing non-ASCII characters must survive the round trip */ + + for (const std::wstring& current : expressions) { + ExprPtr expression = v1::from_string(current); + + REQUIRE(v1::to_string(expression, {.annot_symm = true}) == current); + } + + SECTION("result_expressions") { /* same round trip for strings of the form lhs = rhs, parsed into a ResultExpr */ + std::vector expressions = { + L"A = 5", + L"A = g{i_1,i_2;e_1,e_2}:S-N-S * t{e_1,e_2;i_1,i_2}:N-N-S", + L"R{i_1,i_2;e_1,e_2}:A-N-S = f{e_2;e_3}:A-N-S * " + L"t{e_1,e_3;i_1,i_2}:A-N-S + " + L"g{i_1,i_2;e_1,e_2}:A-N-S", + }; + + for (const std::wstring& current : expressions) { + ResultExpr result = v1::from_string(current); + + REQUIRE(v1::to_string(result, {.annot_symm = true}) == current); + } + } + } + } + + SECTION("v2") { /* checks for the v2 entry points; inputs here are narrow string literals */ + SECTION("from_string") { + SECTION("Constant") { + SECTION("Integer") { + Constant constant = v2::from_string("5"); + REQUIRE(constant.value() == 5); + } + SECTION("Float") { + 
Constant constant = v2::from_string("3.14"); + REQUIRE_THAT(constant.value(), + ::Catch::Matchers::WithinAbs(3.14, 1e-6)); /* compared with an absolute tolerance of 1e-6 instead of exact equality */ + } + SECTION("Imaginary") { /* a trailing i marks a purely imaginary constant */ + Constant constant = v2::from_string("2i"); + REQUIRE(constant.value() == Constant::scalar_type(0, 2)); + + constant = v2::from_string("3 i"); /* whitespace between the magnitude and the i suffix is expected to be accepted */ + REQUIRE(constant.value() == Constant::scalar_type(0, 3)); + } + } + } + } +} diff --git a/tests/unit/test_utilities.cpp b/tests/unit/test_utilities.cpp index 761e5f89b4..9a00494f5b 100644 --- a/tests/unit/test_utilities.cpp +++ b/tests/unit/test_utilities.cpp @@ -630,20 +630,18 @@ TEST_CASE("utilities", "[utilities]") { "'4 ' could not be fully parsed as a"))); } SECTION("float") { - if constexpr (string_to_supports) { - REQUIRE_THAT(string_to("42"), - WithinAbs(42, std::numeric_limits::epsilon())); - REQUIRE_THAT(string_to("3.14"), - WithinAbs(3.14, std::numeric_limits::epsilon())); - REQUIRE_THAT( - string_to("-3.14159"), - WithinAbs(-3.14159, std::numeric_limits::epsilon())); - if constexpr (string_to_supports) { - REQUIRE_THAT(string_to("2.7182818284590"), - WithinAbs(2.7182818284590, - std::numeric_limits::epsilon())); - } + REQUIRE_THAT(string_to("42"), + WithinAbs(42, std::numeric_limits::epsilon())); + REQUIRE_THAT(string_to("3.14"), + WithinAbs(3.14, std::numeric_limits::epsilon())); + REQUIRE_THAT(string_to("-3.14159"), + WithinAbs(-3.14159, std::numeric_limits::epsilon())); + REQUIRE_THAT( + string_to("2.7182818284590"), + WithinAbs(2.7182818284590, std::numeric_limits::epsilon())); + // Error handling is only available if std::from_chars is supported + if constexpr (from_chars_supports) { REQUIRE_THROWS_MATCHES( string_to(" 3.14"), ConversionException, MessageMatches(ContainsSubstring("' 3.14' is not a valid")));