In [14]:
import torch
from math import sqrt
import numpy as np
import torch.multiprocessing as mp
import os
import torch.distributed as dist
import time

In [15]:
def kronecker(A, B):
    if not isinstance(A, torch.Tensor):
        return B
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))

Z = torch.tensor([
    [1.0, 0],
    [0, -1.0]
], dtype=torch.cfloat)

X = torch.tensor([
    [0, 1.0],
    [1.0,  0]
], dtype=torch.cfloat)

Y = torch.tensor([
    [0, -1.0j],
    [1.0j, 0]
], dtype=torch.cfloat)

H = math.sqrt(2)/2*torch.tensor([
    [1, 1],
    [1, -1]
], dtype=torch.cfloat)

I = torch.eye(2, dtype=torch.cfloat)

def RX(phi):
    gate = torch.zeros((2, 2), dtype=torch.cfloat)
    gate[0, 0], gate[0, 1], gate[1, 0], gate[1, 1] = torch.cos(phi/2), torch.complex(torch.tensor(0, dtype=torch.float32), -torch.sin(phi/2)), torch.complex(torch.tensor(0, dtype=torch.float32), -torch.sin(phi/2)), torch.cos(phi/2)
    return gate.to(phi.device)

def RY(phi):
    gate = torch.zeros((2, 2), dtype=torch.cfloat)
    gate[0, 0], gate[0, 1], gate[1, 0], gate[1, 1] = torch.cos(phi/2), -torch.sin(phi/2), torch.sin(phi/2), torch.cos(phi/2)
    return gate.to(phi.device)

def RZ(phi):
    gate = torch.zeros((2, 2), dtype=torch.cfloat)
    gate[0, 0], gate[1, 1] = torch.complex(torch.cos(phi/2), -torch.sin(phi/2)), torch.complex(torch.cos(phi/2), torch.sin(phi/2))
    return gate.to(phi.device)

def Rot(alpha, beta, theta):
    return torch.mm(torch.mm(RZ(alpha), RY(beta)), RZ(theta))

def Frax(n):
    axis = n / torch.norm(n)
    gate = torch.zeros((2, 2), dtype=torch.cfloat)
    gate[0, 0], gate[0, 1], gate[1, 0], gate[1, 1] = -axis[2]*1j, -axis[0]*1j-axis[1], -axis[0]*1j+axis[1], axis[2]*1j
    return gate.to(n.device)

def State00():
    state = torch.zeros((2, 2), dtype=torch.cfloat)
    state[0, 0] = 1
    return state

def State11():
    state = torch.zeros((2, 2), dtype=torch.cfloat)
    state[1, 1] = 1
    return state

def CNOT(wires=[0, 1]):
    # First qubit : Cnntrol, Second qubit : Flipped
    if wires[0] < wires[1]:
        return torch.tensor([
            [1,0,0,0],
            [0,1,0,0],
            [0,0,0,1],
            [0,0,1,0]], dtype=torch.cfloat)
    else:
        return torch.tensor([
            [1,0,0,0],
            [0,0,0,1],
            [0,0,1,0],
            [0,1,0,0]], dtype=torch.cfloat)
    
RGate = {
    'RX': RX,
    'RY': RY,
    'RZ': RZ
}

CZ = torch.tensor([
    [1,0,0,0],
    [0,1,0,0], 
    [0,0,1,0], 
    [0,0,0,-1]], dtype=torch.cfloat)

def CRR(phi, wires=[0, 1], name='RX'):
    if wires[0] < wires[1]:
        first, second = State00(), State11()
        for i in range(wires[0], wires[1]):
            if i == wires[1] - 1:
                first = kronecker(first, torch.eye(2))
                second = kronecker(second, RGate[name](phi))
            else:
                first = kronecker(first, torch.eye(2))
                second = kronecker(second, torch.eye(2))
        return first + second
    else:
        first, second = torch.eye(2), RGate[name](phi)
        for i in range(wires[1], wires[0]):
            if i == wires[0] - 1:
                first = kronecker(first, State00)
                second = kronecker(second, State11)
            else:
                first = kronecker(first, torch.eye(2))
                second = kronecker(second, torch.eye(2))
        return first + second

In [16]:
def CZ_layer(n_qubits):
    if n_qubits == 2:
        return CZ
    gate1 = CZ
    for i in range(2, n_qubits, 2):
        if i+1 < n_qubits:
            gate1 = kronecker(gate1, CZ)
        else:
            gate1 = kronecker(gate1, I)
    gate2 = CZ
    gate2 = kronecker(I, gate2)
    for i in range(3, n_qubits, 2):
        if i+1 < n_qubits:
            gate2 = kronecker(gate2, CZ)
        else:
            gate2 = kronecker(gate2, I)
    return torch.mm(gate2, gate1)

def Frax_ansatz(n_qubits, param):
    # param : torch.Tensor of (n_qubits, 3)
    x = 1
    for i in range(n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)

def replace_Frax_ansatz(n_qubits, measured_qubit, observable, param):
    x = 1
    for i in range(measured_qubit):
        x = kronecker(x, Frax(param[i]))
        
    if observable == 'X':
        x = kronecker(x, X)
    elif observable == 'Y':
        x = kronecker(x, Y)
    elif observable == 'Z':
        x = kronecker(x, Z)
    elif observable == 'XY':
        x = kronecker(x, (X+Y)/sqrt(2))
    elif observable == 'XZ':
        x = kronecker(x, (X+Z)/sqrt(2))
    elif observable == 'YZ':
        x = kronecker(x, (Y+Z)/sqrt(2))

    for i in range(measured_qubit+1, n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)

In [17]:
def amplitude_embedding(feat, n_qubits):
    # feat : torch.tensor of 2^n_qubits elements
    if feat.ndim == 1:
        feat = feat.reshape(-1,).to(torch.complex64)
        feat /= torch.norm(feat)
    elif feat.ndim == 2:
        feat = feat.reshape(-1, 2**n_qubits,).to(torch.complex64)
        feat = feat.transpose(0,1) / torch.norm(feat, dim=1)
        feat = feat.transpose(0,1)
    return feat

In [22]:
def lastbit_Z(state):
    return torch.sum(state[0:len(state):2].abs()**2)-torch.sum(state[1:len(state):2].abs()**2)

class FraxClassify():
    def __init__(self, n_qubits, layer_size, measure_iter, world_size):
        self.n_qubits = n_qubits
        self.layer_size = layer_size
        self.measure_iter = measure_iter
        self.params = (torch.zeros(layer_size, n_qubits, 3) + 1/sqrt(3)).to(torch.complex64)
        self.world_size =world_size
        self.train_acc = []
        self.test_acc = []
        
    def fit(self, train):
        params = self.params
        train_feat, train_label = train
        x = amplitude_embedding(train_feat, self.n_qubits)
        for a in range(self.layer_size):
            for b in range(self.n_qubits):
                R = torch.zeros(3,3)
                for c in range(train_feat.shape[0]):
                    y = x[c]
                    for d in range(a):
                        y = Frax_ansatz(self.n_qubits, params[d]) @ y
                    rx = replace_Frax_ansatz(self.n_qubits, b, 'X', params[a]) @ y
                    ry = replace_Frax_ansatz(self.n_qubits, b, 'Y', params[a]) @ y
                    rz = replace_Frax_ansatz(self.n_qubits, b, 'Z', params[a]) @ y
                    rxy = replace_Frax_ansatz(self.n_qubits, b, 'XY', params[a]) @ y
                    rxz = replace_Frax_ansatz(self.n_qubits, b, 'XZ', params[a]) @ y
                    ryz = replace_Frax_ansatz(self.n_qubits, b, 'YZ', params[a]) @ y
                    for d in range(a+1, self.layer_size):
                        rx = Frax_ansatz(self.n_qubits, params[d]) @ rx
                        ry = Frax_ansatz(self.n_qubits, params[d]) @ ry       
                        rz = Frax_ansatz(self.n_qubits, params[d]) @ rz
                        rxy = Frax_ansatz(self.n_qubits, params[d]) @ rxy
                        rxz = Frax_ansatz(self.n_qubits, params[d]) @ rxz        
                        ryz = Frax_ansatz(self.n_qubits, params[d]) @ ryz
                        
                    rx = lastbit_Z(rx)
                    ry = lastbit_Z(ry)
                    rz = lastbit_Z(rz)
                    rxy = lastbit_Z(rxy)
                    rxz = lastbit_Z(rxz)
                    ryz = lastbit_Z(ryz)
                        
                    R[0,0] += train_label[c] * 2 * rx
                    R[0,1] += train_label[c] * (2 * rxy-rx-ry)
                    R[0,2] += train_label[c] * (2 * rxz-rx-rz)
                    R[1,1] += train_label[c] * 2 * ry
                    R[1,2] += train_label[c] * (2 * ryz-ry-rz)
                    R[2,1] += train_label[c] * 2 * rz
                    
                R[1,0] = R[0,1]
                R[2,0] = R[0,2]
                R[2,1] = R[1,2]
                group = dist.new_group(range(self.world_size))
                dist.all_reduce(R, op=dist.ReduceOp.SUM, group=group)
                if (dist.get_rank(group) == 0):
                    print(R)
                eigenvalues, eigenvectors = torch.linalg.eig(R)
                self.params[a, b] = eigenvectors[torch.argmax(eigenvalues.real)]
                self.params[a, b] /= torch.norm(self.params[a, b])
                
    def eval(self, train, test):
        test_score = 0
        train_score = 0
        train_feat, train_label = train
        test_feat, test_label = test
        train_size = train_label.shape[0]
        test_size = test_label.shape[0]
        for a in range(test_size):
            x = amplitude_embedding(test_feat[a], self.n_qubits)
            for b in range(self.layer_size):
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            test_score += test_label[a] * lastbit_Z(x)
        group = dist.new_group(range(self.world_size))
        dist.all_reduce(test_score, op=dist.ReduceOp.SUM, group=group)
        self.test_acc.append(test_score)
        
        for a in range(train_size):
            x = amplitude_embedding(train_feat[a], self.n_qubits)
            for b in range(self.layer_size):
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            train_score += train_label[a] * lastbit_Z(x)
        group = dist.new_group(range(self.world_size))
        dist.all_reduce(train_score, op=dist.ReduceOp.SUM, group=group)
        self.train_acc.append(train_score)
        
    def get_accuracy(self):
        if dist.get_rank() == 0:
            print(self.train_acc, self.test_acc)

In [19]:
def data_loader():
    try:
        test_label = torch.from_numpy(np.load('drive/MyDrive/mnist_test_Label.npy'))
        train_label = torch.from_numpy(np.load('drive/MyDrive/mnist_train_Label.npy'))
        test_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_test_feat.npy'))
        train_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_train_feat.npy'))
        return test_label, train_label, test_feat, train_feat
    except Exception as e:
        print(e)
    
def cut_data(train_label, train_feat, test_label, test_feat, rank, world_size):
    data_len_min = len(train_feat) // world_size
    offset = len(train_feat) % world_size
    if rank < offset:
        start1 = rank*(data_len_min+1)
        end1 = start1+data_len_min+1
    else:
        start1 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end1 = start1+data_len_min
    data_len_min = len(test_feat) // world_size
    offset = len(test_feat) % world_size
    if rank < offset:
        start2 = rank*(data_len_min+1)
        end2 = start2+data_len_min+1
    else:
        start2 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end2 = start2+data_len_min
    
    return train_label[start1:end1], train_feat[start1:end1], test_label[start2:end2], test_feat[start2:end2]

def parallel_train(rank, world_size, layer_size, update_iter, measure_iter):
    print('I am ', rank)
    n_qubits = 6
    test_label, train_label, test_feat, train_feat = data_loader()
    train_label, train_feat, test_label, test_feat = cut_data(train_label, train_feat, test_label, test_feat, rank, world_size)
    model = FraxClassify(n_qubits, layer_size, measure_iter, world_size)
    for i in range(update_iter):
        model.fit(train=(train_feat, train_label))
        model.eval(train=(train_feat, train_label), test=(test_feat, test_label))
    model.get_accuracy()

In [23]:
def init_process(rank, world_size, layer_size, update_iter, measure_iter, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size, layer_size, update_iter, measure_iter)

if __name__ == '__main__':
    print("hello")
    W = 4
    L = 2
    N = 6
    processes = []
    st = time.time()
    for rank in range(W):
        p = mp.Process(target=init_process, args=(rank, W, L, N, 100, parallel_train))
        p.start()
        processes.append(p)
        
    for p in processes:
        p.join()
        
    print('Implementation time : ', time.time()-st)

hello
I am I am I am I am    2 01
3


tensor([[-5.3357e+01, -6.1393e-06, -5.5730e-06],
        [-6.1393e-06, -5.3357e+01, -7.4804e-06],
        [-5.5730e-06, -7.4804e-06,  0.0000e+00]])
tensor([[-5.3357e+01, -4.1127e-06, -6.6757e-06],
        [-4.1127e-06, -5.3357e+01, -2.4438e-06],
        [-6.6757e-06, -2.4438e-06,  0.0000e+00]])
tensor([[-5.3357e+01, -1.2070e-05, -9.1791e-06],
        [-1.2070e-05, -5.3357e+01, -1.0043e-05],
        [-9.1791e-06, -1.0043e-05,  0.0000e+00]])
tensor([[-5.3357e+01, -8.1956e-06, -6.4373e-06],
        [-8.1956e-06, -5.3357e+01, -9.7454e-06],
        [-6.4373e-06, -9.7454e-06,  0.0000e+00]])
tensor([[-1.5965e+02, -1.7494e-05,  3.2973e-01],
        [-1.7494e-05, -1.5965e+02,  3.5117e-01],
        [ 3.2973e-01,  3.5117e-01,  0.0000e+00]])
tensor([[18.1630, -0.8389, -2.0307],
        [-0.8389, 15.6359,  7.7572],
        [-2.0307,  7.7572,  0.0000]])
tensor([[ 1.4280e+01,  1.1921e-07, -1.7583e-06],
        [ 1.1921e-07,  1.4280e+01,  2.4140e-06],
        [-1.