In [1]:
import torch.multiprocessing as mp
import torch.distributed as dist
import time
import os
import torch
import numpy as np
from math import sqrt

In [62]:
from IPython.core.displaypub import CapturingDisplayPublisher
def init_process(rank, world_size, layer_size, update_iter, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size, layer_size, update_iter)
def parallel_train(rank, world_size, layer_size, update_iter):
    print('I am ', rank)
    n_qubits = 6
    train, test, train_size, test_size = data_loader(rank, world_size)
    model = FraxClassify(n_qubits, layer_size, world_size)
    acc = []
    for i in range(update_iter):
        model.fit(train=train)
        (train_acc, test_acc), (train_score, test_score) = model.eval(train=train, test=test)
        if rank == 0:
            print(train_acc / train_size, test_acc / test_size, train_score, test_score)
def data_loader(rank, world_size):
    try:
        test_label = torch.from_numpy(np.load('drive/MyDrive/mnist_test_Label.npy'))
        train_label = torch.from_numpy(np.load('drive/MyDrive/mnist_train_Label.npy'))
        test_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_test_feat.npy'))
        train_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_train_feat.npy'))
    except Exception as e:
        print(e)
    data_len_min = len(train_feat) // world_size
    offset = len(train_feat) % world_size
    if rank < offset:
        start1 = rank*(data_len_min+1)
        end1 = start1+data_len_min+1
    else:
        start1 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end1 = start1+data_len_min
    data_len_min = len(test_feat) // world_size
    offset = len(test_feat) % world_size
    if rank < offset:
        start2 = rank*(data_len_min+1)
        end2 = start2+data_len_min+1
    else:
        start2 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end2 = start2+data_len_min
    return (train_feat[start1:end1], train_label[start1:end1]), (test_feat[start2:end2], test_label[start2:end2]), train_label.shape[0], test_label.shape[0]
def lastbit_Z(state):
    return 2 * (torch.norm(state[0:len(state):2])**2) - 1
def amplitude_embedding(feat, n_qubits):
    # feat : torch.tensor of 2^n_qubits elements
    if feat.ndim == 1:
        feat = feat.reshape(-1,).to(torch.complex64)
        feat /= torch.norm(feat)
    elif feat.ndim == 2:
        feat = feat.reshape(-1, 2**n_qubits,).to(torch.complex64)
        feat = feat.transpose(0,1) / torch.norm(feat, dim=1)
        feat = feat.transpose(0,1)
    return feat
def kronecker(A, B):
    if not isinstance(A, torch.Tensor):
        return B
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
CZ = torch.tensor([
    [1,0,0,0],
    [0,1,0,0], 
    [0,0,1,0], 
    [0,0,0,-1]], dtype=torch.cfloat)
def CZ_layer(n_qubits):
    if n_qubits == 2:
        return CZ
    gate1 = CZ
    for i in range(2, n_qubits, 2):
        if i+1 < n_qubits:
            gate1 = kronecker(gate1, CZ)
        else:
            gate1 = kronecker(gate1, I)
    gate2 = CZ
    gate2 = kronecker(I, gate2)
    for i in range(3, n_qubits, 2):
        if i+1 < n_qubits:
            gate2 = kronecker(gate2, CZ)
        else:
            gate2 = kronecker(gate2, I)
    return torch.mm(gate2, gate1)
Z = torch.tensor([
    [1.0, 0],
    [0, -1.0]
], dtype=torch.complex64)
X = torch.tensor([
    [0, 1.0],
    [1.0,  0]
], dtype=torch.complex64)
Y = torch.tensor([
    [0, -1.0j],
    [1.0j, 0]
], dtype=torch.complex64)
I = torch.eye(2, dtype=torch.complex64)
def Frax(n):
    n = n / torch.norm(n)
    return n[0] * X + n[1] * Y + n[2] * Z
def Frax_ansatz(n_qubits, param):
    # param : torch.Tensor of (n_qubits, 3)
    x = 1
    for i in range(n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)
def replace_Frax_ansatz(n_qubits, measured_qubit, observable, param):
    x = 1
    for i in range(measured_qubit):
        x = kronecker(x, Frax(param[i]))
    if observable == 'X':
        x = kronecker(x, X)
    elif observable == 'Y':
        x = kronecker(x, Y)
    elif observable == 'Z':
        x = kronecker(x, Z)
    elif observable == 'XY':
        x = kronecker(x, (X+Y)/sqrt(2))
    elif observable == 'XZ':
        x = kronecker(x, (X+Z)/sqrt(2))
    elif observable == 'YZ':
        x = kronecker(x, (Y+Z)/sqrt(2))
    for i in range(measured_qubit+1, n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)
class FraxClassify():
    def __init__(self, n_qubits, layer_size, world_size):
        self.n_qubits = n_qubits
        self.layer_size = layer_size
        self.params = (torch.zeros(layer_size, n_qubits, 3) + 1/sqrt(3)).to(torch.complex64)
        self.world_size =world_size
        self.train_acc = []
        self.test_acc = []    
    def fit(self, train):
        params = self.params
        train_feat, train_label = train
        x = amplitude_embedding(train_feat, self.n_qubits)
        for a in range(self.layer_size):
            for b in range(self.n_qubits):
                R = torch.zeros(3,3)
                for c in range(train_feat.shape[0]):
                    y = x[c]
                    for d in range(a):
                        y = Frax_ansatz(self.n_qubits, params[d]) @ y
                    rx = replace_Frax_ansatz(self.n_qubits, b, 'X', params[a]) @ y
                    ry = replace_Frax_ansatz(self.n_qubits, b, 'Y', params[a]) @ y
                    rz = replace_Frax_ansatz(self.n_qubits, b, 'Z', params[a]) @ y
                    rxy = replace_Frax_ansatz(self.n_qubits, b, 'XY', params[a]) @ y
                    rxz = replace_Frax_ansatz(self.n_qubits, b, 'XZ', params[a]) @ y
                    ryz = replace_Frax_ansatz(self.n_qubits, b, 'YZ', params[a]) @ y
                    for d in range(a+1, self.layer_size):
                        rx = Frax_ansatz(self.n_qubits, params[d]) @ rx
                        ry = Frax_ansatz(self.n_qubits, params[d]) @ ry       
                        rz = Frax_ansatz(self.n_qubits, params[d]) @ rz
                        rxy = Frax_ansatz(self.n_qubits, params[d]) @ rxy
                        rxz = Frax_ansatz(self.n_qubits, params[d]) @ rxz        
                        ryz = Frax_ansatz(self.n_qubits, params[d]) @ ryz                       
                    rx = lastbit_Z(rx)
                    ry = lastbit_Z(ry)
                    rz = lastbit_Z(rz)
                    rxy = lastbit_Z(rxy)
                    rxz = lastbit_Z(rxz)
                    ryz = lastbit_Z(ryz)                       
                    R[0,0] += train_label[c] * 2 * rx
                    R[0,1] += train_label[c] * (2 * rxy - rx - ry)
                    R[0,2] += train_label[c] * (2 * rxz - rx - rz)
                    R[1,1] += train_label[c] * 2 * ry
                    R[1,2] += train_label[c] * (2 * ryz - ry - rz)
                    R[2,1] += train_label[c] * 2 * rz                 
                R[1,0] = R[0,1]
                R[2,0] = R[0,2]
                R[2,1] = R[1,2]
                group = dist.new_group(range(self.world_size))
                dist.all_reduce(R, op=dist.ReduceOp.SUM, group=group)
                eigenvalues, eigenvectors = torch.linalg.eigh(R)
                self.params[a, b] = eigenvectors[:, torch.argmax(eigenvalues)]
                (vgb,_), (sgb,_) = self.eval(train,train)
                if dist.get_rank() == 0:
                    print(torch.max(eigenvalues))
                    print(vgb, sgb)
                            
    def eval(self, train, test):
        group = dist.new_group(range(self.world_size))
        cri = torch.zeros(4)
        train_feat, train_label = train
        test_feat, test_label = test
        train_size = train_label.shape[0]
        test_size = test_label.shape[0]
        for a in range(test_size):
            x = amplitude_embedding(test_feat[a], self.n_qubits)
            for b in range(self.layer_size):
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            if test_label[a] * lastbit_Z(x) > 0:
                cri[1] += 1
            cri[3] += test_label[a] * lastbit_Z(x)
        for a in range(train_size):
            x = amplitude_embedding(train_feat[a], self.n_qubits)
            for b in range(self.layer_size):
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            if train_label[a] * lastbit_Z(x) > 0:
                cri[0] += 1
            cri[2] += train_label[a] * lastbit_Z(x)
        dist.all_reduce(cri, op=dist.ReduceOp.SUM, group=group)
        return (cri[0], cri[1]), (2*cri[2], 2*cri[3])

In [65]:
W = 4 # World_size
L = 4 # Layer_size
N = 10 # Iteration_size

In [66]:
processes = []
st = time.time()
for rank in range(W):
    p = mp.Process(target=init_process, args=(rank, W, L, N, parallel_train))
    p.start()
    processes.append(p)
        
for p in processes:
    p.join()
        
print('Implementation time : ', time.time()-st)

I am I am I am  I am  3
2 
 01

tensor(1.9074e-06)
tensor(21.) tensor(-31.0530)
tensor(-9.5362e-07)
tensor(21.) tensor(-31.0530)
tensor(8.2442)
tensor(94.) tensor(22.0340)
tensor(30.0270)
tensor(113.) tensor(39.0208)
tensor(12.6113)
tensor(115.) tensor(38.5114)


Process Process-53:
Process Process-56:
Traceback (most recent call last):
Process Process-55:
Process Process-54:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "<ipython-input-62-db211145138a>", line 6, in init_process
    fn(rank, world_size, layer_size, update_iter)
  File "/usr/lib/python3.7/multiprocessing/proce

KeyboardInterrupt: ignored

In [61]:
v = torch.tensor([[-7.4738e-01+0.j, -6.6439e-01+0.j, -1.3709e-06+0.j],
         [ 7.2640e-01+0.j,  6.8727e-01+0.j,  1.5497e-06+0.j],
         [ 5.3045e-01+0.j,  8.4772e-01+0.j,  1.6689e-06+0.j],
         [-6.5305e-01+0.j, -7.5731e-01+0.j, -1.0133e-06+0.j],
         [ 6.8999e-01+0.j,  7.2382e-01+0.j,  1.4305e-06+0.j],
         [ 4.1626e-01+0.j,  7.7486e-07+0.j,  9.0925e-01+0.j]])
(train_feat, train_label), _, _, _ = data_loader(0, 1)
x = amplitude_embedding(train_feat, 6)
R = torch.zeros(3,3)
cccc = 5
for c in range(train_feat.shape[0]):
    y = x[c]
    rx = lastbit_Z(replace_Frax_ansatz(6, cccc, 'X', v) @ y)
    ry = lastbit_Z(replace_Frax_ansatz(6, cccc, 'Y', v) @ y)
    rz = lastbit_Z(replace_Frax_ansatz(6, cccc, 'Z', v) @ y)
    rxy = lastbit_Z(replace_Frax_ansatz(6, cccc, 'XY', v) @ y)
    rxz = lastbit_Z(replace_Frax_ansatz(6, cccc, 'XZ', v) @ y)
    ryz = lastbit_Z(replace_Frax_ansatz(6, cccc, 'YZ', v) @ y)
                                          
    R[0,0] += train_label[c] * 2 * rx
    R[0,1] += train_label[c] * (2 * rxy - rx - ry)
    R[0,2] += train_label[c] * (2 * rxz - rx - rz)
    R[1,1] += train_label[c] * 2 * ry
    R[1,2] += train_label[c] * (2 * ryz - ry - rz)
    R[2,1] += train_label[c] * 2 * rz                 
R[1,0] = R[0,1]
R[2,0] = R[0,2]
R[2,1] = R[1,2]
                
eigenvalues, eigenvectors = torch.linalg.eigh(R)
print(eigenvalues)
print(eigenvectors)
print(torch.dot(eigenvectors[:,2], torch.mv(R, eigenvectors[:,2])))
score = 0
w = v
w[cccc] = eigenvectors[:,2]
for c in range(train_feat.shape[0]):
    y = x[c]
    score += train_label[c] * lastbit_Z(Frax_ansatz(6, w) @ y)
print(2*score)

tensor([-64.1412, -50.6982,  13.4429])
tensor([[ 9.0925e-01, -3.8973e-07,  4.1626e-01],
        [-8.0466e-07, -1.0000e+00,  7.7486e-07],
        [-4.1626e-01,  1.0729e-06,  9.0925e-01]])
tensor(13.4429)
tensor(55.3568)


In [40]:
M=100
nrand = torch.rand(M, 3)
maxscore = 0
for i in range(M):
    cb = nrand[i] / torch.norm(nrand[i])
    if maxscore < torch.dot(cb, torch.mv(R, cb)):
        maxscore = torch.dot(cb, torch.mv(R, cb))
print(maxscore)

tensor(12.6407)
