# Monkey Patching Amazon Braket: Extending the `Circuit` Class

In this notebook, we add custom capabilities to the Amazon Braket `Circuit` class by writing our own functions that operate on quantum circuits defined with Braket's SDK. We use a technique called monkey patching to attach these functions to Braket's `Circuit` class at runtime, thereby extending its functionality. We work with two examples: a function we call `get_unitary()`, which returns the unitary matrix corresponding to a `Circuit` object, and a function called `adjoint()`, which returns a new circuit object corresponding to the conjugate transpose (adjoint) of the original circuit. Both functions take `self` as their first argument, so they can be added to the `Circuit` class as methods. Amazon Braket already includes a method called `as_unitary()`, which lets us compare the built-in functionality with our own method.

### Table of Contents
* [Introduction](#introduction)
  * [Braket's `Circuit` class](#braketcircuit)
* [Monkey patching `get_unitary` and `adjoint` to the `Circuit` class](#monkeypatch)
* [Testing `get_unitary`](#testgetunitary)
* [Testing `adjoint`](#testadjoint)
* [Appendix](#appendix)
  * [Intro to Monkey Patching](#introtomonkeypatch)
  * [Custom function definitions](#functiondefinitions)
* [References](#references)

## Introduction<a name="introduction"></a>

When working with a library, we may occasionally need to extend or override the functionality defined in a module without editing its files on disk or modifying the library itself. We can accomplish this using a technique called _Monkey Patching_ [[1]](#references), which dynamically replaces or modifies code at runtime. The interested reader can find a short introduction to monkey patching in the [Appendix](#appendix).

### Braket's `Circuit` class<a name="braketcircuit"></a>

Let's use monkey patching to expand what we can do with Braket. At the time of writing, Braket's `Circuit` class provides the built-in `as_unitary()` method for exporting the unitary matrix of a circuit, which we use below as a reference for our own `get_unitary`, but it has no method for obtaining the adjoint of a circuit. We therefore define our own custom functions and monkey patch them onto the `Circuit` class.

The `get_unitary` and `adjoint` functions are defined in the `utils_circuit.py` module, and their definitions are shown in the [Appendix](#appendix) for completeness.

In [1]:
# AWS imports: Import Braket SDK modules
from braket.circuits import Circuit, circuit

# Local imports:
from utils_circuit import get_unitary, adjoint

## Monkey patching `get_unitary` and `adjoint` to the `Circuit` class<a name="monkeypatch"></a>

Let's now attach our functions to the `Circuit` class as methods.

In [2]:
Circuit.get_unitary = get_unitary
Circuit.adjoint = adjoint

## Testing `get_unitary`<a name="testgetunitary"></a>

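The `get_unitary` method exports the unitary matrix corresponding to the circuit. It only returns a matrix and does not modify the circuit, so it plays nicely with the rest of the SDK. To check that our code works as expected, we compare its output with that of the built-in `as_unitary` method. Note that the two methods order qubits differently: `get_unitary` treats qubit 0 as the most significant bit (big-endian), whereas `as_unitary` in this SDK version treats qubit 0 as the least significant bit (little-endian). The outputs therefore agree for the X, SWAP, and Z examples below, whose matrices are unchanged when the qubit order is reversed, but differ for the CNOT example.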

In [3]:
# Print Pauli X
circ1 = Circuit().x(0)
print("get_unitary:\n", circ1.get_unitary().real)
print("as_unitary:\n", circ1.as_unitary().real)

get_unitary:
 [[0. 1.]
 [1. 0.]]
as_unitary:
 [[0. 1.]
 [1. 0.]]


In [4]:
# Print CNOT
circ2 = Circuit().cnot(0,1)
print("get_unitary:\n", circ2.get_unitary().real)
print("as_unitary:\n", circ2.as_unitary().real)

get_unitary:
 [[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]]
as_unitary:
 [[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]]
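The two CNOT matrices above differ only in qubit ordering. One way to check this is to reverse the qubit order of a matrix: reshape it into a tensor with one axis per qubit index, reverse the qubit axes on both the row and column side, and reshape back. The helper `reverse_qubit_order` is a hypothetical name for this NumPy-only sketch; it is not part of Braket.

```python
import numpy as np


def reverse_qubit_order(matrix):
    """Convert a 2^n x 2^n matrix between big-endian and little-endian qubit order."""
    dim = matrix.shape[0]
    num_qubits = dim.bit_length() - 1  # assumes dim is a power of two
    # One axis per qubit: first the n output (row) indices, then the n input (column) indices
    tensor = np.reshape(matrix, [2] * 2 * num_qubits)
    # Reverse the qubit axes on both the output side and the input side
    axes = list(range(num_qubits - 1, -1, -1)) + list(range(2 * num_qubits - 1, num_qubits - 1, -1))
    return np.reshape(np.transpose(tensor, axes), (dim, dim))


# CNOT(0, 1) as returned by get_unitary (qubit 0 is the most significant bit)
cnot_get_unitary = np.array([[1, 0, 0, 0],
                             [0, 1, 0, 0],
                             [0, 0, 0, 1],
                             [0, 0, 1, 0]])
print(reverse_qubit_order(cnot_get_unitary))
```

Applied to the `get_unitary` CNOT matrix, the helper reproduces the `as_unitary` matrix printed above. Since reversing the order twice is the identity operation, the same function converts in either direction.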


In [5]:
# Print SWAP
circ3 = Circuit().swap(0,1)
print("get_unitary:\n", circ3.get_unitary().real)
print("as_unitary:\n", circ3.as_unitary().real)

get_unitary:
 [[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]
as_unitary:
 [[1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 0. 0. 1.]]


In [6]:
# Print a larger yet still tractable circuit
circ4 = Circuit().z(range(3))
print("get_unitary:\n", circ4.get_unitary().real)
print("as_unitary:\n", circ4.as_unitary().real)

get_unitary:
 [[ 1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0. -1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  0. -1.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  0. -1.]]
as_unitary:
 [[ 1.  0.  0.  0.  0.  0.  0.  0.]
 [ 0. -1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0. -1.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  0. -1.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  0. -1.]]


## Testing `adjoint`<a name="testadjoint"></a>

The `adjoint` method returns a new circuit corresponding to the adjoint of the original circuit: the order of the gates is reversed, and each gate is replaced by its adjoint. Native gates with a known native adjoint are mapped to native gates (for example, `S` becomes `Si`, and `Rx` with angle θ becomes `Rx` with angle -θ), while any other gate, such as a custom `Unitary` gate, is replaced by a custom unitary gate whose matrix is the conjugate transpose of the original gate's matrix. Such a circuit is well defined within the SDK, but custom unitary gates may not be supported by all backends. For example, at the time of writing, neither the Rigetti QPU nor the IonQ QPU supports the `Unitary` gate, so circuits whose adjoint contains custom unitaries should be limited to simulators for the time being.
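For native gates, `adjoint` relies on two facts: the adjoint of a product reverses the order, (AB)^H = B^H A^H, and a rotation's adjoint is the same rotation with its angle negated. Both can be checked numerically with NumPy alone. This sketch hand-codes `rx` and `ry` from the standard definitions Rx(θ) = exp(-iθX/2) and Ry(θ) = exp(-iθY/2); these are assumed rather than taken from Braket, but with θ = 1 they reproduce the matrices printed in In [8] below.

```python
import numpy as np


def rx(theta):
    """Rx(theta) = exp(-i * theta * X / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]])


def ry(theta):
    """Ry(theta) = exp(-i * theta * Y / 2)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


theta = 1.0
# Circuit rx(theta) followed by ry(theta): the later gate multiplies from the left
original = ry(theta) @ rx(theta)
# Adjoint circuit: ry(-theta) followed by rx(-theta), i.e. reversed order and negated angles
adjoint_matrix = rx(-theta) @ ry(-theta)

print(np.allclose(adjoint_matrix, original.conj().T))     # True
print(np.allclose(original @ adjoint_matrix, np.eye(2)))  # True
print(original.round(3))
```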

In [7]:
# Simple circuit: XY. Adjoint should reverse the order of gates:
adj_circ1 = Circuit().x(0).y(0)
print(adj_circ1)
print('Adjoint:')
print(adj_circ1.adjoint())

T  : |0|1|
          
q0 : -X-Y-

T  : |0|1|
Adjoint:
T  : |0|1|
          
q0 : -Y-X-

T  : |0|1|


In [8]:
# Simple circuit with non-trivial unitary matrix: RxRy.
adj_circ2 = Circuit().rx(0,1).ry(0,1)
print("Original unitary matrix:")
print(adj_circ2.get_unitary().round(3))
print('Adjoint matrix:')
print(adj_circ2.adjoint().get_unitary().round(3))


Original unitary matrix:
[[ 0.77 +0.23j  -0.421-0.421j]
 [ 0.421-0.421j  0.77 -0.23j ]]
Adjoint matrix:
[[ 0.77 -0.23j   0.421+0.421j]
 [-0.421+0.421j  0.77 +0.23j ]]


In [9]:
# Larger circuit:
adj_circ3 = Circuit().s(0).cnot(0,1).t(1).h(2).cnot(0,2).x(1).h(range(3)).cy(0,2).x(2).y(0)
print(adj_circ3)
print("Adjoint:")
print(adj_circ3.adjoint())

T  : |0|1| 2 |3| 4 |5|
                      
q0 : -S-C---C-H---C-Y-
        |   |     |   
q1 : ---X-T-|-X-H-|---
            |     |   
q2 : -H-----X-H---Y-X-

T  : |0|1| 2 |3| 4 |5|
Adjoint:
T  : |0| 1 |2 |3|4|5 |
                      
q0 : -Y-C---H--C-C-Si-
        |      | |    
q1 : -H-|-X-Ti-|-X----
        |      |      
q2 : -X-Y---H--X-H----

T  : |0| 1 |2 |3|4|5 |


In [10]:
# Adjoint of circuit with custom unitary (given by diag(1, i))
import numpy as np
adj_circ4 = Circuit().unitary(matrix=np.diag([1,1j]),targets=[0])
print(adj_circ4)
print("Adjoint:")
print(adj_circ4.adjoint())

T  : |0|
        
q0 : -U-

T  : |0|
Adjoint:
T  : |0 |
         
q0 : -UH-

T  : |0 |


In [11]:
# Print the matrices:
print("Original:")
print(adj_circ4.get_unitary())
print("Adjoint:")
print(adj_circ4.adjoint().get_unitary())

Original:
[[1.+0.j 0.+0.j]
 [0.+0.j 0.+1.j]]
Adjoint:
[[1.+0.j 0.+0.j]
 [0.+0.j 0.-1.j]]


We can see from the above examples that monkey patching has allowed us to extend Braket's `Circuit` class with `get_unitary` and `adjoint` methods that behave as expected, with `get_unitary` differing from `as_unitary` only in its qubit-ordering convention.

## Appendix<a name="appendix"></a>

In [12]:
import braket._sdk as braket_sdk
print("braket_sdk version: ", braket_sdk.__version__)

braket_sdk version:  1.8.0


### Intro to Monkey Patching<a name="introtomonkeypatch"></a>

Consider a module `foo` containing a class `Bar`, which in turn has a set of methods, variables, etc. Suppose we wanted to add a method to `Bar` without directly editing the code of the class. We can do this by defining a new function `hello_world` somewhere else and _monkey patching_ the function onto `Bar` at runtime:
```python
from foo import Bar
def hello_world(self):  # `self` receives the instance the method is called on
    ...
    
Bar.hello_world = hello_world # This is the monkey patch
```
The code above defines the function `hello_world`, attaches it to the `Bar` class, and allows us to call `hello_world` on an instance of the `Bar` class.
```python
foobar = Bar()
foobar.hello_world()
```
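Because `hello_world` is set on the class, the patch is visible on every instance, including instances created before the patch was applied. To patch a single object instead, a function can be bound to that one instance with `types.MethodType` from the standard library. The sketch below uses a stand-in `Bar` class in place of the hypothetical `foo.Bar`.

```python
import types


class Bar:
    """Stand-in for the `Bar` class from the hypothetical module `foo`."""


existing = Bar()  # created *before* the class is patched


def hello_world(self):
    # `self` is the instance on which the method is called
    return f"Hello from {type(self).__name__}"


# Class-level patch: visible on every instance, including ones created earlier
Bar.hello_world = hello_world
print(existing.hello_world())  # Hello from Bar


def shout(self):
    return "HELLO"


# Instance-level patch: bind the function to one object only
special = Bar()
special.shout = types.MethodType(shout, special)
print(special.shout())  # HELLO
print(hasattr(existing, "shout"))  # False
```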

In addition to adding new methods to a class, one can also modify existing variables and attributes by overwriting them. For example:
```python
import numpy as np
np.pi = 3.0 # np.pi is now equal to 3.0
```
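Overwriting a module attribute affects every piece of code that reads it afterwards, so a patch like this is safer when it is temporary. The standard-library context manager `unittest.mock.patch.object` replaces an attribute and restores the original on exit. This sketch uses the `math` module rather than NumPy so that it depends only on the standard library.

```python
import math
from unittest import mock

original_pi = math.pi

# Inside the `with` block, math.pi is replaced by 3.0
with mock.patch.object(math, "pi", 3.0):
    patched_pi = math.pi

# On exit, the original value is restored (this also happens if the block raises)
restored_pi = math.pi

print(patched_pi, restored_pi)  # 3.0 3.141592653589793
```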

#### Pitfalls of Monkey Patching:
Modifying code at runtime with a monkey patch can cause problems (see [[1]](#references) for additional details). Because a monkey patch overwrites code at runtime, a poorly implemented or poorly documented patch can cause unexpected behavior. For instance, if the same attribute is monkey patched more than once, the most recent patch silently overwrites any earlier ones, which causes problems if done unintentionally. Moreover, a monkey patch can dramatically alter the behavior of code relative to the unpatched version, which is hard to diagnose if the patch is not documented. Finally, a monkey patch is not updated along with the code it modifies, so an update to the patched module can break the patch (or vice versa).
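One way to guard against accidentally overwriting an existing attribute, including an earlier patch, is to check for it before patching. The helper `safe_patch` below is a hypothetical name for a minimal sketch; it is not part of Braket or the standard library. It refuses to replace an attribute unless explicitly asked to.

```python
def safe_patch(cls, name, func, overwrite=False):
    """Attach `func` to `cls` as attribute `name`.

    Raises AttributeError if `cls` (or one of its base classes) already defines
    `name`, unless `overwrite=True` is passed explicitly.
    """
    if hasattr(cls, name) and not overwrite:
        raise AttributeError(
            f"{cls.__name__} already has an attribute {name!r}; "
            "pass overwrite=True to replace it"
        )
    setattr(cls, name, func)


class Bar:
    """Stand-in class used only to demonstrate the helper."""


def hello_world(self):
    return "hello"


safe_patch(Bar, "hello_world", hello_world)  # the first patch succeeds
try:
    safe_patch(Bar, "hello_world", lambda self: "bye")  # a second patch is rejected
except AttributeError as err:
    print(err)
```

In this notebook, the same idea would amount to calling, for example, `safe_patch(Circuit, "adjoint", adjoint)` instead of assigning `Circuit.adjoint` directly, so that a name clash surfaces as an error rather than a silent replacement.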

### Custom function definitions<a name="functiondefinitions"></a>

#### The definition of the `get_unitary` function is as follows:

```python
import numpy as np
from braket.circuits import Circuit, circuit

def get_unitary(self):
    """
    Funtion to get the unitary matrix corresponding to an entire circuit.
    Acts on self and returns the corresponding unitary
    """
    max_qubit_index = int(max(self.qubits)+1)
    num_qubits = self.qubit_count
    circ = Circuit()
    
    if num_qubits != max_qubit_index:
        circ.add_circuit(self, target=range(num_qubits))
    else:
        circ.add_circuit(self)
    
    # Define the unitary matrix. Start with the identity matrix.
    # Reshape the unitary into a tensor with the right number of indices (given by num_qubits)
    unitary = np.reshape(np.eye(2**num_qubits, 2**num_qubits), [2] * 2 * num_qubits)
    
    # Iterate over the moments in the circuit
    for key in circ.moments:
        
        # Get the matrix corresponding to the gate
        matrix = circ.moments[key].operator.to_matrix()
        # Get the target indices for the gate
        targets = circ.moments[key].target

        # Reshape the gate matrix
        gate_matrix = np.reshape(matrix, [2] * len(targets) * 2)
        
        # Construct a tuple specifying the axes along which we contract (i.e., which qubits the gate acts on)
        axes = (
            np.arange(len(targets), 2 * len(targets)),
            targets,
        )
        
        # Apply the gate by contracting the existing unitary with the new gate
        unitary = np.tensordot(gate_matrix, unitary, axes=axes)

        # tensordot places the gate's output axes first, followed by the remaining
        # axes of the unitary in their original order. We need to invert this
        # permutation to put the indices back in the correct place.
        
        # Find the indices that were not contracted
        unused_idxs = [idx for idx in range(2*num_qubits) if idx not in targets]
        
        # Axis i of the current tensor belongs at position permutation[i]
        permutation = list(targets) + unused_idxs
        
        # Find the permutation that undoes this reordering
        inverse_permutation = np.argsort(permutation)
        
        # Relabel the qubits according to this inverse_permutation
        unitary = np.transpose(unitary, inverse_permutation)

    # Reshape to a 2^N x 2^N matrix (for N = num_qubits) and return
    unitary = np.reshape(unitary, (2**num_qubits, 2**num_qubits))
    return unitary
```
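The core of `get_unitary` is the tensor contraction performed for each gate. The sketch below reproduces that contraction with NumPy alone so that it can be checked without Braket. Gates are given as `(matrix, targets)` pairs instead of Braket instructions, and qubit 0 is the most significant bit, as in `get_unitary`; the names `apply_gate` and `circuit_unitary` are illustrative only.

```python
import numpy as np


def apply_gate(unitary, gate, targets, num_qubits):
    """Left-multiply a unitary stored as a [2] * 2n tensor by `gate` acting on `targets`."""
    k = len(targets)
    gate_tensor = np.reshape(gate, [2] * 2 * k)
    # Contract the gate's input axes with the unitary's output axes on the target qubits
    result = np.tensordot(gate_tensor, unitary, axes=(list(range(k, 2 * k)), list(targets)))
    # tensordot puts the gate's output axes first; move them back to the target positions
    unused = [idx for idx in range(2 * num_qubits) if idx not in targets]
    return np.transpose(result, np.argsort(list(targets) + unused))


def circuit_unitary(gates, num_qubits):
    """Multiply out a list of (matrix, targets) pairs given in circuit order."""
    dim = 2 ** num_qubits
    unitary = np.reshape(np.eye(dim, dtype=complex), [2] * 2 * num_qubits)
    for gate, targets in gates:
        unitary = apply_gate(unitary, gate, targets, num_qubits)
    return np.reshape(unitary, (dim, dim))


X = np.array([[0, 1], [1, 0]])
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
u = circuit_unitary([(X, [0]), (CNOT, [0, 1])], num_qubits=2)
print(np.allclose(u, CNOT @ np.kron(X, np.eye(2))))  # True
```

The usage example applies X to qubit 0 and then CNOT(0, 1); as expected, the result equals the matrix product with the later gate on the left, and X on qubit 0 appears as `np.kron(X, I)` because qubit 0 is the most significant bit.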

#### The definition of the `adjoint` function is as follows:

```python
import numpy as np
from braket.circuits import Circuit

def adjoint(self):
    """Generates a circuit object corresponding to the adjoint of a given circuit, in which the order
    of gates is reversed, and each gate is the adjoint (i.e., conjugate transpose) of the original.
    """

    adjoint_circ = Circuit()
    
    # Loop through the instructions (gates) in the circuit:
    for instruction in self.instructions:   
        # Save the operator name and target
        op_name = instruction.operator.name
        target = instruction.target
        angle = None
        # If the operator has an attribute called 'angle', save that too
        if hasattr(instruction.operator,'angle'):
            angle = instruction.operator.angle

        # To make use of native gates, we define the adjoint of each supported native gate explicitly
        if op_name == "H":
            adjoint_gate = Circuit().h(target)
        elif op_name == "I":
            adjoint_gate = Circuit().i(target)
        elif op_name == "X":
            adjoint_gate = Circuit().x(target)
        elif op_name == "Y":
            adjoint_gate = Circuit().y(target)
        elif op_name == "Z":
            adjoint_gate = Circuit().z(target)
        elif op_name == "S":
            adjoint_gate = Circuit().si(target)
        elif op_name == "Si":
            adjoint_gate = Circuit().s(target)
        elif op_name == "T":
            adjoint_gate = Circuit().ti(target)
        elif op_name == "Ti":
            adjoint_gate = Circuit().t(target)
        elif op_name == "V":
            adjoint_gate = Circuit().vi(target)
        elif op_name == "Vi":
            adjoint_gate = Circuit().v(target)
        elif op_name == "Rx":
            adjoint_gate = Circuit().rx(target,-angle)
        elif op_name == "Ry":
            adjoint_gate = Circuit().ry(target,-angle)
        elif op_name == "Rz":
            adjoint_gate = Circuit().rz(target,-angle)
        elif op_name == "PhaseShift":
            adjoint_gate = Circuit().phaseshift(target,-angle)
        elif op_name == "CNot":
            adjoint_gate = Circuit().cnot(*target)
        elif op_name == "Swap":
            adjoint_gate = Circuit().swap(*target)
        elif op_name == "ISwap":
            adjoint_gate = Circuit().pswap(*target,-np.pi/2)
        elif op_name == "PSwap":
            adjoint_gate = Circuit().pswap(*target,-angle)
        elif op_name == "XY":
            adjoint_gate = Circuit().xy(*target,-angle)
        elif op_name == "CPhaseShift":
            adjoint_gate = Circuit().cphaseshift(*target,-angle)
        elif op_name == "CPhaseShift00":
            adjoint_gate = Circuit().cphaseshift00(*target,-angle)
        elif op_name == "CPhaseShift01":
            adjoint_gate = Circuit().cphaseshift01(*target,-angle)
        elif op_name == "CPhaseShift10":
            adjoint_gate = Circuit().cphaseshift10(*target,-angle)
        elif op_name == "CY":
            adjoint_gate = Circuit().cy(*target)
        elif op_name == "CZ":
            adjoint_gate = Circuit().cz(*target)
        elif op_name == "XX":
            adjoint_gate = Circuit().xx(*target,-angle)
        elif op_name == "YY":
            adjoint_gate = Circuit().yy(*target,-angle)
        elif op_name == "ZZ":
            adjoint_gate = Circuit().zz(*target,-angle)
        elif op_name == "CCNot":
            adjoint_gate = Circuit().ccnot(*target)
        elif op_name == "CSwap":
            adjoint_gate = Circuit().cswap(*target)

        # Otherwise (e.g., for a custom unitary), create a new custom unitary gate
        else:
            # Compute the adjoint (conjugate transpose) of the gate's unitary matrix
            adjoint_matrix = instruction.operator.to_matrix().T.conj()
            
            # Define a gate whose unitary matrix is the adjoint found above.
            # Append an "H" to the original display name (e.g., "U" becomes "UH").
            adjoint_gate = Circuit().unitary(matrix=adjoint_matrix, targets=instruction.target, display_name="".join(instruction.operator.ascii_symbols)+"H")
            
        # Add the new gate to the adjoint circuit. Note the order of operations here:
        # (AB)^H = B^H A^H, where H is adjoint, thus we prepend new gates, rather than append.
        adjoint_circ = adjoint_gate.add(adjoint_circ)
    return adjoint_circ
```
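The long `if`/`elif` chain in `adjoint` encodes three kinds of rules: self-adjoint gates map to themselves, some gates map to a fixed partner (for example `S` to `Si`), and parametrized gates keep their name but negate the angle, with `ISwap` as the one special case (mapped to `PSwap` with angle -π/2). The sketch below is a possible refactoring, not the notebook's implementation: it collects these rules into lookup tables copied from the function above. `adjoint_spec` is a hypothetical helper name.

```python
import math

# Gates that are their own adjoint (the same gate on the same targets)
SELF_ADJOINT = {"H", "I", "X", "Y", "Z", "CNot", "Swap", "CY", "CZ", "CCNot", "CSwap"}

# Gates whose adjoint is a different, fixed gate
ADJOINT_PARTNER = {"S": "Si", "Si": "S", "T": "Ti", "Ti": "T", "V": "Vi", "Vi": "V"}

# Parametrized gates whose adjoint is the same gate with the angle negated
NEGATE_ANGLE = {
    "Rx", "Ry", "Rz", "PhaseShift", "PSwap", "XY",
    "CPhaseShift", "CPhaseShift00", "CPhaseShift01", "CPhaseShift10",
    "XX", "YY", "ZZ",
}


def adjoint_spec(op_name, angle=None):
    """Return (adjoint_gate_name, adjoint_angle), or None if no native rule applies.

    `adjoint_angle` is None for gates without a parameter.
    """
    if op_name in SELF_ADJOINT:
        return op_name, None
    if op_name in ADJOINT_PARTNER:
        return ADJOINT_PARTNER[op_name], None
    if op_name in NEGATE_ANGLE:
        return op_name, -angle
    if op_name == "ISwap":
        # ISwap equals PSwap(pi/2), so its adjoint is PSwap(-pi/2)
        return "PSwap", -math.pi / 2
    return None  # e.g. a custom Unitary: fall back to the conjugate-transpose matrix
```

Every builder method used in `adjoint` is the lowercased gate name (for example, `CPhaseShift00` and `cphaseshift00`), so the returned name could be turned into a method with `getattr(Circuit(), name.lower())`, while a `None` result would fall through to the custom-unitary branch. The trade-off is that the tables keep the rules in one place, at the cost of an indirection through `getattr`.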

---
## References<a name="references"></a>
[1] Wikipedia: Monkey Patch [https://en.wikipedia.org/wiki/Monkey_patch](https://en.wikipedia.org/wiki/Monkey_patch)