<a href="https://colab.research.google.com/github/Izuho/senior-thesis/blob/main/multiClassify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.multiprocessing as mp
import torch.distributed as dist
import time
import os
import torch
import numpy as np
from math import sqrt
from torchvision import datasets, transforms
import cv2

In [20]:
MM = 4

In [21]:
mnist_train = datasets.MNIST(
    root='./data',
    download=True,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
mnist_test = datasets.MNIST(
    root='./data',
    download=True,
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

x_train = mnist_train.train_data.numpy()
y_train = mnist_train.train_labels.numpy()
x_test = mnist_test.test_data.numpy()
y_test = mnist_test.test_labels.numpy()

x_train = x_train[np.where(y_train < MM, True, False)]
y_train = y_train[np.where(y_train < MM, True, False)]
x_test = x_test[np.where(y_test < MM, True, False)]
y_test = y_test[np.where(y_test < MM, True, False)]

N = x_train.shape[0]
M = x_test.shape[0]
new_x_train = np.zeros((N,8,8))
new_x_test = np.zeros((M,8,8))
for i in range(N):
    new_x_train[i] = cv2.resize(x_train[i], (8,8), cv2.INTER_AREA)
for i in range(M):
    new_x_test[i] = cv2.resize(x_test[i], (8,8), cv2.INTER_AREA)

new2_x_test = new_x_test.reshape(-1,64)
new2_x_train = new_x_train.reshape(-1,64)

new3_x_test = new2_x_test / np.sqrt(np.sum(new2_x_test*new2_x_test, axis = 1)).reshape(-1,1)
new3_x_train = new2_x_train / np.sqrt(np.sum(new2_x_train*new2_x_train, axis = 1)).reshape(-1,1)

save_root = './'
if not os.path.exists(save_root):
    os.mkdir(save_root)
np.save(os.path.join(save_root, 'drive/MyDrive/mnist_x_train'), new3_x_train)
np.save(os.path.join(save_root, 'drive/MyDrive/mnist_y_train'), y_train)
np.save(os.path.join(save_root, 'drive/MyDrive/mnist_x_test'), new3_x_test)
np.save(os.path.join(save_root, 'drive/MyDrive/mnist_y_test'), y_test)

In [26]:
TR_SIZE = 800
TE_SIZE = 200

In [27]:
def init_process(rank, world_size, layer_size, update_iter, Q, fn, backend='gloo'):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size, layer_size, update_iter, Q)
def parallel_train(rank, world_size, layer_size, update_iter, Q):
    print('I am ', rank)
    train, test, train_size, test_size = data_loader(rank, world_size)
    model = FraxClassify(layer_size, world_size, Q)
    acc = []
    for i in range(update_iter):
        model.fit(train=train)
        (test_acc, test_score), (train_acc, train_score) = model.eval(train=train, test=test)
        if rank == 0:
            print(train_acc / train_size, test_acc / test_size, train_score, test_score)
def data_loader(rank, world_size):
    try:
        test_label = torch.from_numpy(np.load('drive/MyDrive/mnist_y_test.npy'))[0:TE_SIZE]
        train_label = torch.from_numpy(np.load('drive/MyDrive/mnist_y_train.npy'))[0:TR_SIZE]
        test_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_x_test.npy'))[0:TE_SIZE]
        train_feat = torch.from_numpy(np.load('drive/MyDrive/mnist_x_train.npy'))[0:TR_SIZE]
    except Exception as e:
        print(e)
    data_len_min = len(train_feat) // world_size
    offset = len(train_feat) % world_size
    if rank < offset:
        start1 = rank*(data_len_min+1)
        end1 = start1+data_len_min+1
    else:
        start1 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end1 = start1+data_len_min
    data_len_min = len(test_feat) // world_size
    offset = len(test_feat) % world_size
    if rank < offset:
        start2 = rank*(data_len_min+1)
        end2 = start2+data_len_min+1
    else:
        start2 = offset*(data_len_min+1)+(rank-offset)*data_len_min
        end2 = start2+data_len_min
    return (train_feat[start1:end1], train_label[start1:end1]), (test_feat[start2:end2], test_label[start2:end2]), train_label.shape[0], test_label.shape[0]
def bit_Z(state, bit):
    ans = 0
    for i in range(2**(bit-1)):
        ans += 2 * (torch.norm(state[i:len(state):2**bit])**2) - 1
    return ans
def amplitude_embedding(feat, n, n_qubits):
    feat = feat.reshape(-1,).to(torch.complex64)
    feat = torch.repeat_interleave(feat, 2**(n_qubits-n))
    feat /= torch.norm(feat)
    return feat
def frax_embedding(feat, n, n_qubits):
    count = 0
    ans = torch.eye(2**n_qubits).to(torch.complex64)
    x = 1
    #assert 2**n % n_qubits == 0, 'Error from frax_embedding!'
    for i in range(0, feat.shape[0], 2):
        n = torch.zeros(3).to(torch.complex64)
        n[0], n[1] = feat[i].to(torch.complex64), feat[i+1].to(torch.complex64)
        n[2] = torch.sqrt(1-n[0]**2-n[1]**2)
        x = kronecker(x, Frax(n))
        if (count+1) % n_qubits == 0:
            ans = CZ_layer(n_qubits) @ ans @ x
            x = 1
        count += 1
    return ans
def kronecker(A, B):
    if not isinstance(A, torch.Tensor):
        return B
    return torch.einsum("ab,cd->acbd", A, B).view(A.size(0)*B.size(0),  A.size(1)*B.size(1))
CZ = torch.tensor([
    [1,0,0,0],
    [0,1,0,0], 
    [0,0,1,0], 
    [0,0,0,-1]], dtype=torch.cfloat)
def CZ_layer(n_qubits):
    if n_qubits == 2:
        return CZ
    gate1 = CZ
    for i in range(2, n_qubits, 2):
        if i+1 < n_qubits:
            gate1 = kronecker(gate1, CZ)
        else:
            gate1 = kronecker(gate1, I)
    gate2 = CZ
    gate2 = kronecker(I, gate2)
    for i in range(3, n_qubits, 2):
        if i+1 < n_qubits:
            gate2 = kronecker(gate2, CZ)
        else:
            gate2 = kronecker(gate2, I)
    return torch.mm(gate2, gate1)
X = torch.tensor([[0,1],[1,0]], dtype=torch.complex64)
Y = torch.tensor([[0,-1j],[1j,0]], dtype=torch.complex64)
Z = torch.tensor([[1,0],[0,-1]], dtype=torch.complex64)
XY = (X+Y)/sqrt(2)
XZ = (X+Z)/sqrt(2)
YZ = (Y+Z)/sqrt(2)
I = torch.eye(2, dtype=torch.complex64)
def Frax(n):
    n = n / torch.norm(n)
    return n[0] * X + n[1] * Y + n[2] * Z
def Frax_ansatz(n_qubits, param):
    # param : torch.Tensor of (n_qubits, 3)
    x = 1
    for i in range(n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)
def replace_Frax_ansatz(n_qubits, measured_qubit, observable, param):
    x = 1
    for i in range(measured_qubit):
        x = kronecker(x, Frax(param[i]))
    if observable == 'X':
        x = kronecker(x, X)
    elif observable == 'Y':
        x = kronecker(x, Y)
    elif observable == 'Z':
        x = kronecker(x, Z)
    elif observable == 'XY':
        x = kronecker(x, XY)
    elif observable == 'XZ':
        x = kronecker(x, XZ)
    elif observable == 'YZ':
        x = kronecker(x, YZ)
    for i in range(measured_qubit+1, n_qubits):
        x = kronecker(x, Frax(param[i]))
    return torch.mm(CZ_layer(n_qubits), x)
#converter = [[1],[-1]]
converter = [[-1,-1],[-1,1],[1,-1],[1,1]]
#converter = [(-1,-1,-1), (-1,-1,1), (-1,1,-1), (-1,1,1), (1,-1,-1), (1,-1,1), (1,1,-1), (1,1,1)]
class FraxClassify():
    def __init__(self, layer_size, world_size, Q):
        self.layer_size = layer_size
        self.params = (torch.zeros(layer_size, Q, 3) + 1/sqrt(3)).to(torch.complex64)
        self.world_size = world_size
        self.n_qubits = Q
    def fit(self, train):
        params = self.params
        train_feat, train_label = train
        for a in range(self.layer_size):
            for b in range(self.n_qubits):
                R = torch.zeros(3,3)
                for c in range(train_feat.shape[0]):
                    #y = amplitude_embedding(train_feat[c])
                    y = frax_embedding(train_feat[c], 6, self.n_qubits)[:,0]
                    for d in range(a):
                        y = Frax_ansatz(self.n_qubits, params[d]) @ y
                        y = frax_embedding(train_feat[c], 6, self.n_qubits) @ y
                    rx = replace_Frax_ansatz(self.n_qubits, b, 'X', params[a]) @ y
                    ry = replace_Frax_ansatz(self.n_qubits, b, 'Y', params[a]) @ y
                    rz = replace_Frax_ansatz(self.n_qubits, b, 'Z', params[a]) @ y
                    rxy = replace_Frax_ansatz(self.n_qubits, b, 'XY', params[a]) @ y
                    rxz = replace_Frax_ansatz(self.n_qubits, b, 'XZ', params[a]) @ y
                    ryz = replace_Frax_ansatz(self.n_qubits, b, 'YZ', params[a]) @ y
                    for d in range(a+1, self.layer_size):
                        rx = frax_embedding(train_feat[c], 6, self.n_qubits) @ rx
                        ry = frax_embedding(train_feat[c], 6, self.n_qubits) @ ry
                        rz = frax_embedding(train_feat[c], 6, self.n_qubits) @ rz
                        rxy = frax_embedding(train_feat[c], 6, self.n_qubits) @ rxy
                        rxz = frax_embedding(train_feat[c], 6, self.n_qubits) @ rxz
                        ryz = frax_embedding(train_feat[c], 6, self.n_qubits) @ ryz
                        rx = Frax_ansatz(self.n_qubits, params[d]) @ rx
                        ry = Frax_ansatz(self.n_qubits, params[d]) @ ry       
                        rz = Frax_ansatz(self.n_qubits, params[d]) @ rz
                        rxy = Frax_ansatz(self.n_qubits, params[d]) @ rxy
                        rxz = Frax_ansatz(self.n_qubits, params[d]) @ rxz        
                        ryz = Frax_ansatz(self.n_qubits, params[d]) @ ryz
                    '''               
                    rxs = (bit_Z(rx, 1), bit_Z(rx, 3), bit_Z(rx, 5))
                    rys = (bit_Z(ry, 1), bit_Z(ry, 3), bit_Z(ry, 5))
                    rzs = (bit_Z(rz, 1), bit_Z(rz, 3), bit_Z(rz, 5))
                    rxys = (bit_Z(rxy, 1), bit_Z(rxy, 3), bit_Z(rxy, 5))
                    rxzs = (bit_Z(rxz, 1), bit_Z(rxz, 3), bit_Z(rxz, 5))
                    ryzs = (bit_Z(ryz, 1), bit_Z(ryz, 3), bit_Z(ryz, 5))
                    '''
                    rxs = [bit_Z(rx, 1), bit_Z(rx, 6)]
                    rys = [bit_Z(ry, 1), bit_Z(ry, 6)]
                    rzs = [bit_Z(rz, 1), bit_Z(rz, 6)]
                    rxys = [bit_Z(rxy, 1), bit_Z(rxy, 6)]
                    rxzs = [bit_Z(rxz, 1), bit_Z(rxz, 6)]
                    ryzs = [bit_Z(ryz, 1), bit_Z(ryz, 6)]
                    for d in range(2):
                    #for d in range(3):
                        R[0,0] += converter[train_label[c]][d] * 2 * rxs[d]
                        R[0,1] += converter[train_label[c]][d] * (2 * rxys[d] - rxs[d] - rys[d])
                        R[0,2] += converter[train_label[c]][d] * (2 * rxzs[d] - rxs[d] - rzs[d])
                        R[1,1] += converter[train_label[c]][d] * 2 * rys[d]
                        R[1,2] += converter[train_label[c]][d] * (2 * ryzs[d] - rys[d] - rzs[d])
                        R[2,2] += converter[train_label[c]][d] * 2 * rzs[d]          
                R[1,0] = R[0,1]
                R[2,0] = R[0,2]
                R[2,1] = R[1,2]
                group = dist.new_group(range(self.world_size))
                dist.all_reduce(R, op=dist.ReduceOp.SUM, group=group)
                eigenvalues, eigenvectors = torch.linalg.eigh(R)
                self.params[a, b] = eigenvectors[:, torch.argmax(eigenvalues)]
                if dist.get_rank() == 0: 
                    print(torch.max(eigenvalues))
    def eval(self, train, test):
        group = dist.new_group(range(self.world_size))
        cri = torch.zeros(4)
        train_feat, train_label = train
        test_feat, test_label = test
        train_size = train_label.shape[0]
        test_size = test_label.shape[0]
        for a in range(test_size):
            # x = amplitude_embedding(test_feat[a], 6, self.n_qubits)
            x = torch.eye(2**self.n_qubits).to(torch.complex64)
            for b in range(self.layer_size):
                x = frax_embedding(test_feat[a], 6, self.n_qubits) @ x
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            x = x[:,0]
            if converter[test_label[a]][0] * bit_Z(x,1) > 0 and converter[test_label[a]][1] * bit_Z(x,6) > 0:# and converter[test_label[a]][2] * bit_Z(x[a],5) > 0:
                cri[0] += 1
            cri[1] += converter[test_label[a]][0] * bit_Z(x, 1) + converter[test_label[a]][1] * bit_Z(x, 6)# + converter[test_label[a]][2] * bit_Z(x[a], 5)
        for a in range(train_size):
            # x = amplitude_embedding(train_feat[a], 6, self.n_qubits)
            x = torch.eye(2**self.n_qubits).to(torch.complex64)
            for b in range(self.layer_size):
                x = frax_embedding(train_feat[a], 6, self.n_qubits) @ x
                x = Frax_ansatz(self.n_qubits, self.params[b]) @ x
            x = x[:,0]
            if converter[train_label[a]][0] * bit_Z(x,1) > 0 and converter[train_label[a]][1] * bit_Z(x,6) > 0:# and converter[train_label[a]][2] * bit_Z(x[a],5) > 0:
                cri[2] += 1
            cri[3] += converter[train_label[a]][0] * bit_Z(x, 1) + converter[train_label[a]][1] * bit_Z(x, 6)# + converter[train_label[a]][2] * bit_Z(x[a], 5)
        dist.all_reduce(cri, op=dist.ReduceOp.SUM, group=group)
        return (cri[0], 2*cri[1]), (cri[2], 2*cri[3])

In [28]:
W = 4 # World_size
L = 1 # Layer_size
N = 5 # Iteration_size
Q = 6

In [29]:
processes = []
st = time.time()
for rank in range(W):
    p = mp.Process(target=init_process, args=(rank, W, L, N, Q, parallel_train))
    p.start()
    processes.append(p)
        
for p in processes:
    p.join()
        
print('Implementation time : ', time.time()-st)

I am  2
I am  I am I am   310


tensor(-1198.9880)
tensor(-1198.9885)
tensor(-1198.9883)
tensor(-1198.9885)
tensor(-1198.9883)
tensor(-1140.5768)
tensor(0.2663) tensor(0.2250) tensor(-1140.5768) tensor(-913.8596)
tensor(-1140.5769)
tensor(-1140.5768)
tensor(-1140.5770)
tensor(-1140.5764)
tensor(-1140.5765)
tensor(-1140.5769)
tensor(0.2663) tensor(0.2250) tensor(-1140.5778) tensor(-913.8597)
tensor(-1140.5767)
tensor(-1140.5763)
tensor(-1140.5760)
tensor(-1140.5763)
tensor(-1140.5764)
tensor(-1140.5769)
tensor(0.2663) tensor(0.2250) tensor(-1140.5771) tensor(-913.8600)
tensor(-1140.5771)
tensor(-1140.5764)
tensor(-1140.5764)
tensor(-1140.5769)
tensor(-1140.5771)
tensor(-1140.5769)
tensor(0.2663) tensor(0.2250) tensor(-1140.5773) tensor(-913.8597)
tensor(-1140.5770)
tensor(-1140.5773)
tensor(-1140.5771)
tensor(-1140.5768)
tensor(-1140.5769)
tensor(-1140.5771)
tensor(0.2663) tensor(0.2250) tensor(-1140.5773) tensor(-913.8596)
Implementation time :  513.9678840637207
