In [1]:
import jax.numpy as np
from jax.scipy.linalg import expm
from jax import jit
import numpy as onp
from functools import partial

In [31]:
class Gate:
    
    def __init__(self, dim, qids, n_params=None):
        self.dim = dim
        self.qudit_ids = qids
        self.n_params = n_params
        
    def gate(self, params):
        return np.eye(self.dim)

class RotGate(Gate):
    
    def __init__(self, qids):
        super().__init__(dim=2, qids=qids, n_params=3)
        self.param_ind = None

    def gate(self, params):
        ind = self.param_ind
        theta, phi, rotangle = params[ind:ind+self.n_params]
        s = np.sin(rotangle/2)
        c = np.cos(rotangle/2)
        Rx = np.array([[c, 1j*s],
                       [1j*s, c]])
        Ry = np.array([[c, s],
                       [-s ,c]])
        Rz = np.array([[c+1j*s, 0],
                       [0, c-1j*s]])
        return np.sin(theta)*np.cos(phi)*Rx \
            + np.sin(theta)*np.sin(phi)*Ry \
            + np.cos(theta)*Rz
    
class DispGate(Gate):
    
    def __init__(self, qids, dim):
        super().__init__(dim, qids, n_params=1)
        self.param_ind = None
        self.cr = onp.zeros((self.dim, self.dim))
        self.an = onp.zeros((self.dim, self.dim))
        for i in range(self.dim-1):
            self.cr[i+1, i] = onp.sqrt(i+1)
            self.an[i, i+1] = onp.sqrt(i+1)
    
    def gate(self, params):
        a = params[self.param_ind]
        astar = a.conjugate()
        return expm(a*self.cr - astar*self.an)

class SNAPGate(Gate):
    
    def __init__(self, qubit_id, cavity_id, dim):
        qids = [qubit_id, cavity_id]
        super().__init__(dim, qids, n_params=dim//2)
        self.param_ind = None
    
    def gate(self, params):
        ind = self.param_ind
        theta = params[ind:ind+self.n_params]
        diag = np.exp(np.concatenate((1j*theta, -1j*theta)))
        return np.diag(diag)        

class GateLayer(Gate):
    
    def __init__(self, gates, regInfo):
        self.gates = sorted(gates, key=lambda x: x.qudit_ids[0])
        qids = []
        for g in self.gates:
            qids += g.qudit_ids
        assert len(set(qids)) == len(qids)
        self.regInfo = regInfo
        self.permuted = qids
        super().__init__(self.regInfo.dim, qids)
    
    def gate(self, params):
        mat = 1
        for g in self.gates:
            g = g.gate(params)
            mat = np.kron(mat, g)
        mat = mat.reshape(self.regInfo.shape)
        if self.permuted != self.regInfo.ids:
            mat = np.moveaxis(mat, self.permuted, self.regInfo.ids)
        return mat.reshape((self.regInfo.dim, self.regInfo.dim))

class RegisterInfo:
    
    def __init__(self, register):
        self.register = list(register)
        self.dim = np.prod(register)
        self.shape = self.register + self.register
        self.ids = range(len(self.register))

class Circuit:
    
    def __init__(self, register):
        self.regInfo = RegisterInfo(register)
        self.gates = []
        self.n_params = None
        self.assembled = False
        
    def add_gate(self, gate):
        dim = 1
        for qid in gate.qudit_ids:
            assert qid < len(self.regInfo.register)
            dim *= self.regInfo.register[qid]
        assert dim == gate.dim
        self.gates.append(gate)

    def assemble(self):
        def complete(layer, layers):
            unused = used_ids ^ set(self.regInfo.ids)
            for qid in unused:
                dim = self.regInfo.register[qid]
                layer.append(Gate(dim, [qid], 0))
            layers.append(GateLayer(gates=layer, regInfo=self.regInfo))
        layers = []
        layer = []
        param_ind = 0
        used_ids = set()
        for g in self.gates:
            g.param_ind = param_ind
            param_ind += g.n_params
            g_ids = set(g.qudit_ids)
            if len(g_ids.intersection(used_ids)) == 0:
                layer.append(g)
                used_ids = used_ids.union(g_ids)
            else:
                complete(layer, layers)
        complete(layer, layers)
        self.gates = layers
        self.n_params = param_ind
        self.assembled = True

    @partial(jit, static_argnums=(0,))
    def evaluate(self, params):
        assert self.assembled
        mat = np.eye(self.regInfo.dim)
        for layer in self.gates:
            g = layer.gate(params)
            mat = np.matmul(mat, g)
        return mat
        

In [34]:
c = Circuit([2, 10])
c.add_gate(RotGate([0]))
c.add_gate(DispGate([1], 10))
c.add_gate(SNAPGate(0, 1, 20))
c.add_gate(RotGate([0]))
c.add_gate(SNAPGate(0, 1, 20))
c.add_gate(DispGate([1], 10))
c.add_gate(SNAPGate(0, 1, 20))
c.assemble()
print(c.n_params)
print(len(c.gates))
params = onp.random.rand(c.n_params)
print(c.evaluate(params))

2
10
20
2
20
10
20
38
6
[[ 0.02405502-0.12600096j  0.09758031-0.5111213j   0.19681476-1.0309016j
   0.4119033 -2.1575222j   0.47369167-2.4811628j   0.7634956 -3.9991374j
   0.53175074-2.7852755j   0.9179621 -4.808218j    0.31299484-1.6394482j
   0.93277603-4.8858166j  -0.15481788-0.11050272j -0.6280256 -0.44825685j
  -1.2666918 -0.9041064j  -2.650994  -1.892158j   -3.0486562 -2.1759913j
  -4.9138246 -3.507262j   -3.4223266 -2.4427004j  -5.90796   -4.2168303j
  -2.0144246 -1.4378037j  -6.003307  -4.2848864j ]
 [-0.09758073+0.5111213j  -0.2542832 +1.3319162j  -0.5754374 +3.014104j
  -0.60648805+3.176749j   -0.88341916+4.6272964j  -0.24331169+1.274457j
  -0.5585236 +2.925507j    0.5215972 -2.732087j   -0.20194113+1.0577543j
   0.9389842 -4.9183426j   0.6280267 +0.44825616j  1.6365551 +1.1680976j
   3.703494  +2.6433847j   3.9033403 +2.786026j    5.685657  +4.0581617j
   1.5659537 +1.1177043j   3.5946329 +2.565685j   -3.3569736 -2.3960536j
   1.2996852 +0.92765504j -6.0432725 -4.313412j  ]