<a href="https://colab.research.google.com/github/Izuho/senior-thesis/blob/main/Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [65]:
import torch.multiprocessing as mp
import torch.distributed as dist
import time
import os
import torch
import numpy as np
from math import sqrt

In [69]:
def init_process(rank, world_size, layer_size, update_iter, n_qubits, fn, backend='gloo'):
    """Join the torch.distributed process group, run ``fn``, then leave the group.

    Args:
        rank: This process's rank in ``[0, world_size)``.
        world_size: Total number of processes in the group.
        layer_size, update_iter, n_qubits: Forwarded unchanged to ``fn``.
        fn: Worker, called as ``fn(rank, world_size, layer_size, update_iter, n_qubits)``.
        backend: torch.distributed backend name (default ``'gloo'``).
    """
    # Rendezvous on a fixed loopback address/port: all workers run on this host.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    try:
        fn(rank, world_size, layer_size, update_iter, n_qubits)
    finally:
        # Tear the group down even if fn raises, so its resources are released.
        dist.destroy_process_group()
def parallel_train(rank, world_size, layer_size, update_iter, n_qubits):
    """Train and evaluate a FraxClassify model on this rank's data shard.

    Runs ``update_iter`` rounds of ``fit`` followed by ``eval``. After each
    round, rank 0 prints train/test accuracy and train/test score.

    Returns:
        List with one ``(train_accuracy, test_accuracy)`` pair per round. Each
        accuracy is the correct count returned by ``eval`` divided by the
        global dataset size returned by ``data_loader``.
    """
    print('I am ', rank)
    train, test, train_size, test_size = data_loader(rank, world_size)
    model = FraxClassify(n_qubits, layer_size, world_size)
    acc = []
    for _ in range(update_iter):
        model.fit(train=train)
        (train_acc, test_acc), (train_score, test_score) = model.eval(train=train, test=test)
        # The list was previously allocated but never filled; keep the history.
        acc.append((train_acc / train_size, test_acc / test_size))
        if rank == 0:
            print(train_acc / train_size, test_acc / test_size, train_score, test_score)
    return acc
def data_loader(rank, world_size):
    try:
        test_label = torch.from_numpy(np.load('drive/MyDrive/mnist_test_Label.npy'))[0:200]
        train_label = torch.from_numpy(np.load('drive/MyDrive/mnist_train_Label.npy'))[0:800]
        test_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_test_feat.npy'))[0:200]
        train_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_train_feat.npy'))[0:800]
    except Exception as e:
        print(e)
    data_len_min = len(train_feat) // world_size
    offset = len(train_feat) % world_size
    if rank < offset:
        start1 = rank*(data_len_min+1)
        end1 = start1+data_len_min+1
    else:
        start1 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end1 = start1+data_len_min
    data_len_min = len(test_feat) // world_size
    offset = len(test_feat) % world_size
    if rank < offset:
        start2 = rank*(data_len_min+1)
        end2 = start2+data_len_min+1
    else:
        start2 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end2 = start2+data_len_min
    return (train_feat[start1:end1], train_label[start1:end1]), (test_feat[start2:end2], test_label[start2:end2]), train_label.shape[0], test_label.shape[0]
def lastbit_Z(state):
    """Return 2 * ||state[even indices]||**2 - 1.

    With the Kronecker ordering used in this file (the first factor is the most
    significant index), even indices are the basis states whose last qubit is
    |0>. For a unit-norm state this is therefore the expectation of Pauli-Z on
    the last qubit.
    """
    amp_last_zero = state[::2]
    prob_last_zero = torch.norm(amp_last_zero) ** 2
    return 2 * prob_last_zero - 1
def amplitude_embedding(feat, n, n_qubits):
    """Amplitude-embed features as a unit-norm complex64 state on ``n_qubits`` qubits.

    Args:
        feat: 1-D tensor holding one sample of ``2**n`` features, or a 2-D
            tensor reshaped to rows of ``2**n`` features (one sample per row).
            Each sample is L2-normalized; an all-zero sample yields NaNs.
        n: log2 of the number of features per sample.
        n_qubits: Number of qubits; must be >= ``n``.

    Returns:
        complex64 tensor of shape ``(2**n_qubits,)`` or ``(batch, 2**n_qubits)``.
        When ``n_qubits > n``, each sample is repeated ``2**(n_qubits - n)``
        times (a uniform superposition over the extra leading qubits) and
        rescaled so the state keeps unit norm, as lastbit_Z expects.

    Raises:
        ValueError: if ``feat`` is neither 1-D nor 2-D.
    """
    reps = 2 ** (n_qubits - n)
    if feat.ndim == 1:
        feat = feat.reshape(-1).to(torch.complex64)
        # Out-of-place: .to() returns the input itself when it is already
        # complex64, so an in-place division would mutate the caller's tensor.
        feat = feat / torch.norm(feat)
    elif feat.ndim == 2:
        feat = feat.reshape(-1, 2 ** n).to(torch.complex64)
        feat = feat / torch.norm(feat, dim=1, keepdim=True)
    else:
        raise ValueError(f'feat must be 1-D or 2-D, got {feat.ndim}-D')
    # tile((reps,)) repeats along the last dim; dividing by sqrt(reps) restores
    # unit norm (a no-op when reps == 1).
    return feat.tile((reps,)) / sqrt(reps)
def kronecker(A, B):
    """Return the Kronecker product ``A (x) B`` of two 2-D tensors.

    A non-tensor ``A`` is treated as the empty product and ``B`` is returned
    unchanged. Callers seed a running product with the integer ``1``.
    """
    if not isinstance(A, torch.Tensor):
        return B
    # torch.kron makes A's indices the major (most significant) ones, i.e. the
    # same layout as einsum("ab,cd->acbd") reshaped to (A0*B0, A1*B1), without
    # calling .view() on a possibly non-contiguous einsum result.
    return torch.kron(A, B)
# Two-qubit controlled-Z gate: the diagonal matrix that flips the sign of |11>.
CZ = torch.diag(torch.tensor([1, 1, 1, -1], dtype=torch.cfloat))
def CZ_layer(n_qubits):
    """Return the ``2**n_qubits`` x ``2**n_qubits`` layer of nearest-neighbour CZ gates.

    CZ is applied to every adjacent qubit pair (q, q+1): the pairs (0,1), (2,3),
    ... and (1,2), (3,4), ... . CZ gates are diagonal and commute, so the layer
    is the diagonal matrix whose entry for basis index k is
    ``(-1) ** (number of adjacent qubit pairs that are both 1 in k)``.

    Raises:
        ValueError: if ``n_qubits < 2`` (there is no pair to entangle).
    """
    if n_qubits < 2:
        raise ValueError(f'CZ_layer needs at least 2 qubits, got {n_qubits}')
    basis = torch.arange(2 ** n_qubits)
    # Bit j of both_one is set iff bits j and j+1 of the basis index are both set.
    both_one = basis & (basis >> 1)
    parity = torch.zeros_like(basis)
    for j in range(n_qubits - 1):
        parity ^= (both_one >> j) & 1
    phases = (1 - 2 * parity).to(torch.complex64)
    return torch.diag(phases)
# Single-qubit identity and Pauli matrices, all complex64.
I = torch.eye(2, dtype=torch.complex64)
X = torch.tensor([[0, 1],
                  [1, 0]], dtype=torch.complex64)
Y = torch.tensor([[0, -1j],
                  [1j, 0]], dtype=torch.complex64)
Z = torch.tensor([[1, 0],
                  [0, -1]], dtype=torch.complex64)
# Normalized sums of two Paulis, e.g. XY = (X + Y) / sqrt(2). replace_Frax_ansatz
# substitutes these for a qubit's gate.
XY, XZ, YZ = ((P + Q) / sqrt(2) for P, Q in ((X, Y), (X, Z), (Y, Z)))
def Frax(n):
    """Return the single-qubit operator n_hat[0]*X + n_hat[1]*Y + n_hat[2]*Z.

    ``n`` is a direction vector (only its first three entries are used). It is
    normalized by its full norm first, so ``n_hat = n / ||n||``.
    """
    unit = n / torch.norm(n)
    gate = unit[0] * X
    gate = gate + unit[1] * Y
    gate = gate + unit[2] * Z
    return gate
def Frax_ansatz(n_qubits, param):
    """Return one ansatz layer: a Frax gate on every qubit, then the CZ layer.

    ``param`` is indexed by qubit; per the original note it is a tensor of
    shape (n_qubits, 3). Qubit 0's gate is the leftmost Kronecker factor.
    """
    rotations = 1  # non-tensor seed: kronecker(1, B) returns B
    for qubit in range(n_qubits):
        rotations = kronecker(rotations, Frax(param[qubit]))
    entangler = CZ_layer(n_qubits)
    return torch.mm(entangler, rotations)
def replace_Frax_ansatz(n_qubits, measured_qubit, observable, param):
    """Return a Frax ansatz layer with one qubit's gate swapped for a fixed operator.

    Args:
        n_qubits: Number of qubits in the layer.
        measured_qubit: Index in ``[0, n_qubits)`` of the qubit whose Frax gate
            is replaced.
        observable: One of ``'X'``, ``'Y'``, ``'Z'``, ``'XY'``, ``'XZ'``,
            ``'YZ'``, selecting the module-level matrix of that name.
        param: Per-qubit Frax direction vectors, indexed by qubit.

    Returns:
        ``CZ_layer(n_qubits) @ (G_0 (x) ... (x) G_{n-1})``, where ``G_q`` is
        ``Frax(param[q])`` except at ``measured_qubit``.

    Raises:
        ValueError: for an unknown ``observable`` or out-of-range ``measured_qubit``.
            Previously either case silently produced a wrong-dimension matrix.
    """
    observables = {'X': X, 'Y': Y, 'Z': Z, 'XY': XY, 'XZ': XZ, 'YZ': YZ}
    try:
        replacement = observables[observable]
    except KeyError:
        raise ValueError(
            f'unknown observable {observable!r}; expected one of {sorted(observables)}'
        ) from None
    if not 0 <= measured_qubit < n_qubits:
        raise ValueError(f'measured_qubit {measured_qubit} out of range for {n_qubits} qubits')
    x = 1  # non-tensor seed: kronecker(1, B) returns B
    for i in range(n_qubits):
        factor = replacement if i == measured_qubit else Frax(param[i])
        x = kronecker(x, factor)
    return torch.mm(CZ_layer(n_qubits), x)
class FraxClassify:
    """Layered Frax-ansatz binary classifier trained data-parallel over torch.distributed.

    Args:
        n_qubits: Number of qubits in the circuit.
        layer_size: Number of Frax ansatz layers.
        world_size: Number of participating processes. It is stored as an
            attribute; collectives use the default process group.
        feat_qubits: Features are embedded with
            ``amplitude_embedding(feat, feat_qubits, n_qubits)``. The default 6
            is the value previously hard-coded.

    ``params[layer, qubit]`` is the complex64 direction vector passed to Frax
    for that qubit. Every vector starts at (1, 1, 1) / sqrt(3).
    """

    def __init__(self, n_qubits, layer_size, world_size, feat_qubits=6):
        self.n_qubits = n_qubits
        self.layer_size = layer_size
        self.params = (torch.zeros(layer_size, n_qubits, 3) + 1/sqrt(3)).to(torch.complex64)
        self.world_size = world_size
        self.feat_qubits = feat_qubits

    @staticmethod
    def _lastbit_Z_batch(states):
        """Row-wise lastbit_Z: ``2 * ||row[even indices]||**2 - 1`` for each row."""
        return 2 * torch.norm(states[:, ::2], dim=1) ** 2 - 1

    def _embed(self, feat):
        """Amplitude-embed a batch of samples; an empty shard gives an empty batch."""
        if feat.shape[0] == 0:
            return torch.zeros(0, 2 ** self.n_qubits, dtype=torch.complex64)
        return amplitude_embedding(feat, self.feat_qubits, self.n_qubits)

    def _layers_unitary(self, layers):
        """Return the product of the ansatz layers in ``layers``, applied in order.

        The first layer in ``layers`` acts first. The result is the identity
        if ``layers`` is empty.
        """
        u = torch.eye(2 ** self.n_qubits, dtype=torch.complex64)
        for d in layers:
            u = Frax_ansatz(self.n_qubits, self.params[d]) @ u
        return u

    def fit(self, train):
        """Run one update sweep over every layer ``a`` and qubit ``b``.

        For each (a, b), qubit b's gate in layer a is replaced in turn by X, Y,
        Z, XY, XZ and YZ. The full circuit is applied to every local training
        state and the last-qubit value from lastbit_Z is recorded.
        Label-weighted sums of these values form a symmetric 3x3 matrix R,
        which is summed over all processes with all_reduce. The qubit's
        parameter is then set to the eigenvector of R with the largest
        eigenvalue; rank 0 prints that eigenvalue.
        NOTE(review): this looks like a free-axis (Fraxis-style) update;
        confirm against the method's reference.

        Args:
            train: ``(features, labels)`` for this rank's shard. Labels are
                presumably 1-D and signed (eval tests ``label * <Z> > 0``);
                TODO confirm.
        """
        train_feat, train_label = train
        lab = train_label.reshape(-1)
        x = self._embed(train_feat)
        for a in range(self.layer_size):
            # Layers other than `a` do not change while layer `a` is updated,
            # so their combined unitaries are built once per layer instead of
            # once per sample and per observable.
            prefix = self._layers_unitary(range(a))
            suffix = self._layers_unitary(range(a + 1, self.layer_size))
            for b in range(self.n_qubits):
                z = {}
                for obs in ('X', 'Y', 'Z', 'XY', 'XZ', 'YZ'):
                    m = suffix @ replace_Frax_ansatz(self.n_qubits, b, obs, self.params[a]) @ prefix
                    # Row c of x @ m.T is m @ x[c]: all samples in one matmul.
                    z[obs] = self._lastbit_Z_batch(x @ m.T)
                R = torch.zeros(3, 3)
                R[0, 0] = (lab * (2 * z['X'])).sum()
                R[0, 1] = (lab * (2 * z['XY'] - z['X'] - z['Y'])).sum()
                R[0, 2] = (lab * (2 * z['XZ'] - z['X'] - z['Z'])).sum()
                R[1, 1] = (lab * (2 * z['Y'])).sum()
                R[1, 2] = (lab * (2 * z['YZ'] - z['Y'] - z['Z'])).sum()
                R[2, 2] = (lab * (2 * z['Z'])).sum()
                R[1, 0] = R[0, 1]
                R[2, 0] = R[0, 2]
                R[2, 1] = R[1, 2]
                # The default group already spans every rank. Calling
                # dist.new_group here created a fresh group on every update.
                dist.all_reduce(R, op=dist.ReduceOp.SUM)
                eigenvalues, eigenvectors = torch.linalg.eigh(R)
                self.params[a, b] = eigenvectors[:, torch.argmax(eigenvalues)]
                if dist.get_rank() == 0:
                    print(torch.max(eigenvalues))

    def _correct_and_score(self, feat, label, unitary):
        """Return (#samples with label * <Z> > 0, sum of label * <Z>) for one split."""
        z = self._lastbit_Z_batch(self._embed(feat) @ unitary.T)
        margin = label.reshape(-1) * z
        return (margin > 0).sum(), margin.sum()

    def eval(self, train, test):
        """Evaluate the current circuit on both splits, summed over all processes.

        Returns:
            ``((train_correct, test_correct), (2 * train_score, 2 * test_score))``.
            Here ``*_correct`` counts samples with ``label * <Z_last> > 0`` and
            ``*_score`` is the sum of ``label * <Z_last>``. All values are
            all-reduced over ranks and returned as float32 scalar tensors.
        """
        unitary = self._layers_unitary(range(self.layer_size))
        train_correct, train_score = self._correct_and_score(*train, unitary)
        test_correct, test_score = self._correct_and_score(*test, unitary)
        cri = torch.zeros(4)
        cri[0] = train_correct
        cri[1] = test_correct
        cri[2] = train_score
        cri[3] = test_score
        dist.all_reduce(cri, op=dist.ReduceOp.SUM)
        return (cri[0], cri[1]), (2*cri[2], 2*cri[3])

In [72]:
# Experiment configuration:
#   W - world size (number of worker processes; each gets one data shard)
#   L - layer size (number of Frax ansatz layers)
#   N - iteration size (number of fit/eval rounds)
#   Q - number of qubits
W, L, N, Q = 20, 2, 5, 6

In [73]:
# Launch one worker process per rank, wait for all of them, and report the
# wall-clock time. The guard keeps the launch from running if this file is
# imported as a module; in the notebook __name__ is '__main__'.
if __name__ == '__main__':
    processes = []
    st = time.time()
    for rank in range(W):
        p = mp.Process(target=init_process, args=(rank, W, L, N, Q, parallel_train))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    print('Implementation time : ', time.time()-st)
    # A worker that raised (e.g. missing data files) used to go unnoticed.
    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f'worker ranks {failed} exited with a non-zero exit code')

I am I am I am   1I am 
 I am 0
 I am I am I am  I am I am I am 9I am I am I am 8
  10 I am I am I am I am  
    I am I am 13712 
 
4 
 311172 1415  


516
6
19


18




tensor(-208.4319)
tensor(-208.4319)
tensor(-208.4319)
tensor(-208.4319)
tensor(777.6494)
tensor(807.7505)
tensor(807.7506)
tensor(807.7508)
tensor(807.7507)
tensor(807.7506)
tensor(807.7506)
tensor(815.2742)
tensor(0.9287) tensor(0.9150) tensor(815.2738) tensor(259.8004)
tensor(815.2740)
tensor(815.2742)
tensor(815.2739)
tensor(815.2741)
tensor(815.6616)
tensor(816.5688)
tensor(816.5691)
tensor(816.5688)
tensor(816.5690)
tensor(816.5693)
tensor(816.5693)
tensor(817.1424)
tensor(0.9287) tensor(0.9150) tensor(817.1423) tensor(260.0410)
tensor(817.1426)
tensor(817.1426)
tensor(817.1425)
tensor(817.1426)
tensor(817.1461)
tensor(817.7001)
tensor(817.7001)
tensor(817.7001)
tensor(817.7001)
tensor(817.6998)
tensor(817.7000)
tensor(818.2416)
tensor(0.9275) tensor(0.9150) tensor(818.2411) tensor(260.1296)
tensor(818.2415)
tenso