Skip to content

Commit

Permalink
[Runtime] Handle runtime errors and assertions via throwing a descrip…
Browse files Browse the repository at this point in the history
…tive and unified expression (#92)

* Add a descriptive custom exception for Catalyst Runtime

* Update changelog

* Propagate runtime error messages via stderr (TODO)
  • Loading branch information
maliasadi committed Apr 14, 2023
1 parent 2e03d63 commit c9032f5
Show file tree
Hide file tree
Showing 14 changed files with 240 additions and 150 deletions.
6 changes: 5 additions & 1 deletion doc/changelog.md
Expand Up @@ -4,6 +4,10 @@

<h3>Improvements</h3>

* Improving error handling by throwing descriptive and unified expressions for runtime
errors and assertions.
[#92](https://github.com/PennyLaneAI/catalyst/pull/92)

<h3>Breaking changes</h3>

<h3>Bug fixes</h3>
Expand All @@ -18,7 +22,7 @@

This release contains contributions from (in alphabetical order):

Erick Ochoa Lopez
Ali Asadi, Erick Ochoa Lopez

# Release 0.1.2

Expand Down
91 changes: 91 additions & 0 deletions runtime/include/Exception.hpp
@@ -0,0 +1,91 @@
// Copyright 2023 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <exception>
#include <iostream>

#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

/**
* @brief Macro that throws `RuntimeException` with given message.
*/
#define RT_FAIL(message) Catalyst::Runtime::_abort((message), __FILE__, __LINE__, __func__)

/**
* @brief Macro that throws `RuntimeException` if expression evaluates
* to true.
*/
#define RT_FAIL_IF(expression, message) \
if ((expression)) { \
RT_FAIL(message); \
}

/**
* @brief Macro that throws `RuntimeException` with the given expression
* and source location if expression evaluates to false.
*/
#define RT_ASSERT(expression) RT_FAIL_IF(!(expression), "Assertion: " #expression)

namespace Catalyst::Runtime {

/**
* @brief This is the general exception thrown by Catalyst for runtime errors
* that is derived from `std::exception`.
*/
class RuntimeException : public std::exception {
private:
const std::string err_msg;

public:
explicit RuntimeException(std::string msg) noexcept
: err_msg{std::move(msg)} {} // LCOV_EXCL_LINE
~RuntimeException() override = default; // LCOV_EXCL_LINE

RuntimeException(const RuntimeException &) = default;
RuntimeException(RuntimeException &&) noexcept = default;

RuntimeException &operator=(const RuntimeException &) = delete;
RuntimeException &operator=(RuntimeException &&) = delete;

[[nodiscard]] auto what() const noexcept -> const char * override
{
return err_msg.c_str();
} // LCOV_EXCL_LINE
};

/**
* @brief Throws a `RuntimeException` with the given error message.
*
* @note This is not supposed to be called directly.
*/
[[noreturn]] inline void _abort(const char *message, const char *file_name, size_t line,
const char *function_name)
{
std::stringstream sstream;
sstream << "[" << file_name << "][Line:" << line << "][Function:" << function_name
<< "] Error in Catalyst Runtime: " << message;

// TODO: This should be removed after runtime error
// messages can propagate in the frontend.
std::cerr << sstream.str() << std::endl;

throw RuntimeException(sstream.str());
} // LCOV_EXCL_LINE

} // namespace Catalyst::Runtime
26 changes: 0 additions & 26 deletions runtime/lib/backend/BaseUtils.hpp

This file was deleted.

39 changes: 20 additions & 19 deletions runtime/lib/backend/LightningKokkosSimulator.cpp
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "LightningKokkosSimulator.hpp"
#include "BaseUtils.hpp"

namespace Catalyst::Runtime::Simulator {

Expand Down Expand Up @@ -47,7 +46,7 @@ auto LightningKokkosSimulator::GetNumQubits() const -> size_t

void LightningKokkosSimulator::StartTapeRecording()
{
QFailIf(this->cache_recording, "Cannot re-activate the cache manager");
RT_FAIL_IF(this->cache_recording, "Cannot re-activate the cache manager");
this->cache_recording = true;
this->cache_manager.Reset();
}
Expand Down Expand Up @@ -113,8 +112,8 @@ void LightningKokkosSimulator::NamedOperation(const std::string &name,
Lightning::lookup_gates(Lightning::simulator_gate_info, name);

// Check the validity of number of qubits and parameters
QFailIf((!wires.size() && wires.size() != op_num_wires), "Invalid number of qubits");
QFailIf(params.size() != op_num_params, "Invalid number of parameters");
RT_FAIL_IF((!wires.size() && wires.size() != op_num_wires), "Invalid number of qubits");
RT_FAIL_IF(params.size() != op_num_params, "Invalid number of parameters");

// Convert wires to device wires
auto &&dev_wires = getDeviceWires(wires);
Expand All @@ -135,7 +134,7 @@ void LightningKokkosSimulator::MatrixOperation(const std::vector<std::complex<do
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Check the validity of number of qubits and parameters
QFailIf(!wires.size(), "Invalid number of qubits");
RT_FAIL_IF(!wires.size(), "Invalid number of qubits");

// Convert wires to device wires
auto &&dev_wires = getDeviceWires(wires);
Expand All @@ -160,8 +159,8 @@ void LightningKokkosSimulator::MatrixOperation(const std::vector<std::complex<do
auto LightningKokkosSimulator::Observable(ObsId id, const std::vector<std::complex<double>> &matrix,
const std::vector<QubitIdType> &wires) -> ObsIdType
{
QFailIf(wires.size() > this->GetNumQubits(), "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires");
RT_FAIL_IF(wires.size() > this->GetNumQubits(), "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

auto &&dev_wires = getDeviceWires(wires);

Expand Down Expand Up @@ -208,7 +207,7 @@ inline auto getRealOfComplexInnerProduct(Kokkos::View<Kokkos::complex<Precision>
Kokkos::View<Kokkos::complex<Precision> *> sv2_vec)
-> Precision
{
assert(sv1_vec.size() == sv2_vec.size());
RT_ASSERT(sv1_vec.size() == sv2_vec.size());
Precision inner = 0;
Kokkos::parallel_reduce(
sv1_vec.size(), getRealOfComplexInnerProductFunctor<Precision>(sv1_vec, sv2_vec), inner);
Expand All @@ -220,7 +219,8 @@ auto LightningKokkosSimulator::Expval(ObsIdType obsKey) -> double
using UnmanagedComplexHostView = Kokkos::View<Kokkos::complex<double> *, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

QFailIf(!this->obs_manager.isValidObservables({obsKey}), "Invalid key for cached observables");
RT_FAIL_IF(!this->obs_manager.isValidObservables({obsKey}),
"Invalid key for cached observables");

// update tape caching
if (this->cache_recording) {
Expand All @@ -236,7 +236,8 @@ auto LightningKokkosSimulator::Expval(ObsIdType obsKey) -> double

auto LightningKokkosSimulator::Var(ObsIdType obsKey) -> double
{
QFailIf(!this->obs_manager.isValidObservables({obsKey}), "Invalid key for cached observables");
RT_FAIL_IF(!this->obs_manager.isValidObservables({obsKey}),
"Invalid key for cached observables");

// update tape caching
if (this->cache_recording) {
Expand Down Expand Up @@ -280,8 +281,8 @@ auto LightningKokkosSimulator::PartialProbs(const std::vector<QubitIdType> &wire
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

auto dev_wires = getDeviceWires(wires);

Expand Down Expand Up @@ -319,8 +320,8 @@ auto LightningKokkosSimulator::PartialSample(const std::vector<QubitIdType> &wir
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

// get device wires
auto &&dev_wires = getDeviceWires(wires);
Expand Down Expand Up @@ -387,8 +388,8 @@ auto LightningKokkosSimulator::PartialCounts(const std::vector<QubitIdType> &wir
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

// get device wires
auto &&dev_wires = getDeviceWires(wires);
Expand Down Expand Up @@ -502,9 +503,9 @@ auto LightningKokkosSimulator::Gradient(const std::vector<size_t> &trainParams)
bool is_valid_measurements =
std::all_of(obs_callees.begin(), obs_callees.end(),
[](const auto &m) { return m == Lightning::Measurements::Expval; });
QFailIf(!is_valid_measurements,
"Unsupported measurements to compute gradient; "
"Adjoint differentiation method only supports expectation return type");
RT_FAIL_IF(!is_valid_measurements,
"Unsupported measurements to compute gradient; "
"Adjoint differentiation method only supports expectation return type");

// Create OpsData
auto &&ops_names = this->cache_manager.getOperationsNames();
Expand Down
2 changes: 1 addition & 1 deletion runtime/lib/backend/LightningKokkosSimulator.hpp
Expand Up @@ -21,7 +21,6 @@ throw std::logic_error("StateVectorKokkos.hpp: No such header file");
#define __device_lightning_kokkos

#include <bitset>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>
Expand All @@ -35,6 +34,7 @@ throw std::logic_error("StateVectorKokkos.hpp: No such header file");
#include "StateVectorKokkos.hpp"

#include "CacheManager.hpp"
#include "Exception.hpp"
#include "LightningUtils.hpp"
#include "ObsManager.hpp"
#include "QuantumDevice.hpp"
Expand Down
37 changes: 19 additions & 18 deletions runtime/lib/backend/LightningSimulator.cpp
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "LightningSimulator.hpp"
#include "BaseUtils.hpp"

namespace Catalyst::Runtime::Simulator {

Expand Down Expand Up @@ -58,7 +57,7 @@ auto LightningSimulator::GetNumQubits() const -> size_t { return this->device_sv

void LightningSimulator::StartTapeRecording()
{
QFailIf(this->cache_recording, "Cannot re-activate the cache manager");
RT_FAIL_IF(this->cache_recording, "Cannot re-activate the cache manager");
this->cache_recording = true;
this->cache_manager.Reset();
}
Expand Down Expand Up @@ -117,8 +116,8 @@ void LightningSimulator::NamedOperation(const std::string &name, const std::vect
Lightning::lookup_gates(Lightning::simulator_gate_info, name);

// Check the validity of number of qubits and parameters
QFailIf((!wires.size() && wires.size() != op_num_wires), "Invalid number of qubits");
QFailIf(params.size() != op_num_params, "Invalid number of parameters");
RT_FAIL_IF((!wires.size() && wires.size() != op_num_wires), "Invalid number of qubits");
RT_FAIL_IF(params.size() != op_num_params, "Invalid number of parameters");

// Convert wires to device wires
auto &&dev_wires = getDeviceWires(wires);
Expand Down Expand Up @@ -146,8 +145,8 @@ void LightningSimulator::MatrixOperation(const std::vector<std::complex<double>>
auto LightningSimulator::Observable(ObsId id, const std::vector<std::complex<double>> &matrix,
const std::vector<QubitIdType> &wires) -> ObsIdType
{
QFailIf(wires.size() > this->GetNumQubits(), "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires");
RT_FAIL_IF(wires.size() > this->GetNumQubits(), "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires");

auto &&dev_wires = getDeviceWires(wires);

Expand All @@ -171,7 +170,8 @@ auto LightningSimulator::HamiltonianObservable(const std::vector<double> &coeffs

auto LightningSimulator::Expval(ObsIdType obsKey) -> double
{
QFailIf(!this->obs_manager.isValidObservables({obsKey}), "Invalid key for cached observables");
RT_FAIL_IF(!this->obs_manager.isValidObservables({obsKey}),
"Invalid key for cached observables");
auto &&obs = this->obs_manager.getObservable(obsKey);

// update tape caching
Expand All @@ -186,7 +186,8 @@ auto LightningSimulator::Expval(ObsIdType obsKey) -> double

auto LightningSimulator::Var(ObsIdType obsKey) -> double
{
QFailIf(!this->obs_manager.isValidObservables({obsKey}), "Invalid key for cached observables");
RT_FAIL_IF(!this->obs_manager.isValidObservables({obsKey}),
"Invalid key for cached observables");

auto &&obs = this->obs_manager.getObservable(obsKey);

Expand Down Expand Up @@ -228,8 +229,8 @@ auto LightningSimulator::PartialProbs(const std::vector<QubitIdType> &wires) ->
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

auto dev_wires = getDeviceWires(wires);
Pennylane::Simulators::Measures m{*(this->device_sv)};
Expand Down Expand Up @@ -273,8 +274,8 @@ auto LightningSimulator::PartialSample(const std::vector<QubitIdType> &wires, si
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

// get device wires
auto &&dev_wires = getDeviceWires(wires);
Expand Down Expand Up @@ -352,8 +353,8 @@ auto LightningSimulator::PartialCounts(const std::vector<QubitIdType> &wires, si
const size_t numWires = wires.size();
const size_t numQubits = this->GetNumQubits();

QFailIf(numWires > numQubits, "Invalid number of wires");
QFailIf(!isValidQubits(wires), "Invalid given wires to measure");
RT_FAIL_IF(numWires > numQubits, "Invalid number of wires");
RT_FAIL_IF(!isValidQubits(wires), "Invalid given wires to measure");

// get device wires
auto &&dev_wires = getDeviceWires(wires);
Expand Down Expand Up @@ -461,9 +462,9 @@ auto LightningSimulator::Gradient(const std::vector<size_t> &trainParams)
bool is_valid_measurements =
std::all_of(obs_callees.begin(), obs_callees.end(),
[](const auto &m) { return m == Lightning::Measurements::Expval; });
QFailIf(!is_valid_measurements,
"Unsupported measurements to compute gradient; "
"Adjoint differentiation method only supports expectation return type");
RT_FAIL_IF(!is_valid_measurements,
"Unsupported measurements to compute gradient; "
"Adjoint differentiation method only supports expectation return type");

auto &&state = this->device_sv->getDataVector();

Expand Down Expand Up @@ -507,7 +508,7 @@ auto LightningSimulator::Gradient(const std::vector<size_t> &trainParams)
std::vector<std::vector<double>> results(num_observables);
auto begin_loc_iter = jacobian_t.begin();
for (size_t obs_idx = 0; obs_idx < num_observables; obs_idx++) {
assert(begin_loc_iter != jacobian_t.end());
RT_ASSERT(begin_loc_iter != jacobian_t.end());
results[obs_idx].insert(results[obs_idx].begin(), begin_loc_iter,
begin_loc_iter + num_train_params);
begin_loc_iter += num_train_params;
Expand Down

0 comments on commit c9032f5

Please sign in to comment.