Skip to content

Commit

Permalink
Merge pull request #2 from vtomole/fix_alpha
Browse files Browse the repository at this point in the history
Fix `DiracNotation(alpha*<0|0>*|1>).operate_reduce() ` failure
  • Loading branch information
vtomole committed Mar 20, 2023
2 parents ecf6067 + 68685b5 commit 56549ad
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -1 +1 @@
general-superstaq[dev]~=0.3.9
general-superstaq[dev]~=0.3.21
31 changes: 26 additions & 5 deletions symboliq/dirac_notation.py
@@ -1,6 +1,7 @@
from typing import Any, List
from typing import Any, List, Union

import sympy
from sympy import Symbol, srepr
from sympy.core.numbers import Half
from sympy.physics.quantum import (
Bra,
Expand Down Expand Up @@ -39,7 +40,7 @@
cx = TensorProduct(b_0, i) + TensorProduct(b_3, x)


def qapply(expr: sympy.Expr) -> sympy.Expr:
def qapply(expr: sympy.Expr) -> Union[sympy.Basic, sympy.Expr]:
return DiracNotation(expr).operate_reduce()


Expand Down Expand Up @@ -173,15 +174,28 @@ def _gate_reduce(self, arg: sympy.Expr, add_step: bool) -> sympy.Expr:
return self._add_reduce(arg, add_step)
return self._tensor_reduce(arg, add_step)

def operate_reduce(self) -> sympy.Expr:
def operate_reduce(self) -> Union[sympy.Basic, sympy.Expr]:
"""Iterates through an expression and simplifies incrementally while
keeping track of the steps that are taken to simplify it
Returns:
The simplified expression
"""
assert isinstance(self._expr, sympy.Mul)
rev_args = self._expr.args[::-1]

if "complex=True" in srepr(self._expr):
expr = sympy.physics.quantum.qapply(self._expr)
else:
expr = self._expr
if isinstance(expr, sympy.Mul):
return self.handle_mul(expr)
elif isinstance(expr, sympy.Add):
state = sympy.Integer(0)
for j in expr.args:
state = state + self.handle_mul(j)
return state

def handle_mul(self, expr: sympy.Expr) -> Union[sympy.Basic, sympy.Expr]:
rev_args = expr.args[::-1]
state = rev_args[0]
for i in range(1, len(rev_args)):
rev_args_by_index = rev_args[i]
Expand All @@ -192,6 +206,13 @@ def operate_reduce(self) -> sympy.Expr:
state = self._gate_reduce(base * state, True)

else:
if (
isinstance(rev_args_by_index, Symbol)
and isinstance(state, Ket)
and "complex=True" in srepr(rev_args_by_index)
):
return rev_args_by_index * state
assert hasattr(rev_args_by_index, "__mul__")
state = self._gate_reduce(rev_args_by_index * state, True)
return state

Expand Down
11 changes: 10 additions & 1 deletion symboliq/dirac_notation_test.py
@@ -1,4 +1,4 @@
from sympy import sqrt
from sympy import Symbol, sqrt
from sympy.physics.quantum import TensorProduct
from sympy.physics.quantum.gate import HadamardGate, IdentityGate, XGate, YGate, ZGate
from sympy.physics.quantum.qubit import Qubit
Expand Down Expand Up @@ -144,3 +144,12 @@ def test_get_steps_latex() -> None:
== r"(0) \quad B_{0} {\left|0\right\rangle } \\(1) \quad \left\langle 0 \right. "
r"{\left|0\right\rangle } {\left|0\right\rangle } \\(2) \quad {\left|0\right\rangle } \\"
)


def test_assert_x_gate_on_alpha_and_beta() -> None:
alpha = Symbol("alpha", complex=True)
beta = Symbol("beta", complex=True)

state = x * (alpha * ket_0 + beta * ket_1)

assert str(DiracNotation(state).operate_reduce()) == "alpha*|1> + beta*|0>"

0 comments on commit 56549ad

Please sign in to comment.