In [1]:
import bitstring
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

import initializations as init
import compression as comp

In [2]:
# We need input and output training data
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
train_samples = 5000

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_samples, test_size=10000)

In [6]:
### VERSION: WEIGHT MATRIX OF MPOS, BINARY INPUT ###

# We need to initialize each element of the weight matrix as an MPO
bits = 8  # FLOATS MUST BE 32 OR 64 BIT
W = []
MPO = init.initialize_random_MPO(num_sites=bits, bond_dim=1, phys_dim=2)

layer_inputs = 784
layer_outputs = 10
for i in range(layer_inputs):
    line = []
    for j in range(layer_outputs):
        line.append(MPO[:])
    W.append(line)
print("W Shape:", len(W), "x", len(W[0]))

def convert_vector_to_binary(vector, bits):
    vector_bin = []
    for value in vector:
        value_bin = bitstring.BitArray(float=value, length=bits)  # Add float=value for floats
        vector_bin.append(np.fromstring(value_bin.bin, np.int8) - 48)
    return vector_bin

def create_MPS_from_bit_string(vector_bin, IO):
    MPS_list = []
    for value in vector_bin:
        MPS = []
        for i, element in enumerate(value):
            if element == 0:
                MPS.append(np.array([1, 0]))
            elif element == 1:
                MPS.append(np.array([0, 1]))

            if IO == 'output':
                if i == 0 or i == len(value)-1:
                    MPS[-1] = MPS[-1][..., np.newaxis]
                else:
                    MPS[-1] = MPS[-1][..., np.newaxis, np.newaxis]
                    MPS[-1] = np.transpose(MPS[-1], (1, 2, 0))  # Maybe needs to be 2, 1, 0
        MPS_list.append(MPS)
    return MPS_list

for k in range(1):
    # Preprocessing for a given datapoint
    training_input = X_train[k]

    training_output = np.zeros(10)
    training_output[int(y_train[k])] = 1

    training_input = training_input.astype(int)
    training_output = training_output.astype(int)

    # Convert values to binary
    training_input_bin = []
    for value in training_input:
        input_bin = [int(x) for x in '{:08b}'.format(value)]
        training_input_bin.append(input_bin)

    training_output_bin = []
    for value in training_output:
        output_bin = [int(x) for x in '{:08b}'.format(value)]
        training_output_bin.append(output_bin)

    # Create input/output MPS for our training datapoint
    x = create_MPS_from_bit_string(training_input_bin, IO='input')
    y = create_MPS_from_bit_string(training_output_bin, IO='output')

    index_y = int(y_train[k])
    print(k, index_y)
    for index_x in range(len(x)):
        # Element to be updated
        weight = W[index_x][index_y]
        
        # Binary value corresponding to pixel value
        input_x = x[index_x]
        xW = []
        for i in range(len(weight)):
            if i == 0 or i == len(weight)-1:
                # Contract from top
                site = np.einsum('i, abi->ab', input_x[i], weight[i])
                # Transpose since we move from MPO to MPS index notation
                site = site.T
                xW.append(site)
            else:
                # Contract from top
                site = np.einsum('i, abci->abc', input_x[i], weight[i])
                xW.append(site)

        # Compress weights corresponding to given value in output
        output_y = y[index_y]
        compressed_xW, best_dist, best_sim = comp.compress(output_y, threshold=1e-3, compressed_state=xW[:], plot=0)
        if best_dist[-1] != 0.0:
            print(best_dist[-1])

        # Uses best compressed xW
        updated_xW = compressed_xW[-2]

        # We need to remove x from xW such that x^T xW = W
        x_T = []
        for i, element in enumerate(input_x):
            tranposed = element[..., np.newaxis].T
            x_T.append(tranposed)

        expanded_xW = []
        for i, element in enumerate(updated_xW):
            if i == 0 or i == len(updated_xW)-1:
                # Transpose since we move from MPS to MPO index notation
                element = element.T
                element = element[..., np.newaxis]
            else:
                element = element[..., np.newaxis]
            expanded_xW.append(element)

        updated_weight = []
        for i in range(len(expanded_xW)):
            if i == 0 or i == len(expanded_xW)-1:
                # Contract from top
                site = np.einsum('ij, abi->abj', x_T[i], expanded_xW[i])
                updated_weight.append(site)
            else:
                # Contract from top
                site = np.einsum('ij, abci->abcj', x_T[i], expanded_xW[i])
                updated_weight.append(site)

        W[index_x][index_y] = updated_weight

for i in range(len(W)):
    for j in range(len(W[0])):
        for k in range(len(W[0][0])):
            print(W[i][j][k].all() == W[0][0][k].all())

ist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 BondDim: 1
Sim: 1.0 Dist: 0.0 BondDim: 2
Sim: 1.0 Dist: 0.0 B

In [9]:
X_train[0].shape

(784,)

In [99]:
def vector_to_left_canonical_MPS_NN(tensor, phys_dim):
    """ Decomposes a vector of length d^L (phys_dim^num_sites) into a
        left-canonical MPS. Final site will not be canonical due to
        original norm

    Args:
        tensor: Vector of length that can be described by d^L (Ex: 512 = 2^9)
        phys_dim: Physical dimension necessary on MPS at each site (d)
        num_sites: Number of sites necessary (L)

    Returns:
        A_tensors: Left canonical form of input MPS
    """

    A_tensors = []
    num_sites = len(phys_dim)
    for i in range(0, num_sites-1):
        # Remove one leg such that tensor has shape (d, d^(L-1)) with L sites
        if i == 0:
            reshaped_tensor = np.reshape(tensor, (phys_dim[i],
                                                  tensor.shape[0]//phys_dim[i]))
        else:
            reshaped_tensor = np.reshape(tensor, (A_tensors[-1].shape[1]*phys_dim[i],
                                                  tensor.shape[1]//phys_dim[i]))       
        #print("Reshape:", reshaped_tensor.shape)

        # SVD and save the rank for the next iteration of the loop
        U, S_vector, V = np.linalg.svd(reshaped_tensor, full_matrices=False)
        rank = len(S_vector)

        if i == 0:
            # No need to reshape since U is already a left-canonical matrix
            A_tensors.append(U)
        else:
            # Break apart first leg of U into a left bond dimension
            # and physical dimension
            U = np.reshape(U, (A_tensors[-1].shape[1], phys_dim[i], U.shape[1]))
            # Transpose so that we have the correct shape
            # (left bond, right bond, physical dimension)
            U = np.transpose(U, (0, 2, 1))
            A_tensors.append(U)

        # We recreate the tensor with the remaining legs
        tensor = np.diag(S_vector) @ V
    # Final A tensor is the remaining tensor after all other legs removed
    A_tensors.append(tensor)

    return A_tensors

In [100]:
vector_to_left_canonical_MPS(X_train[0], phys_dim=[4,7,7,4], num_sites=4)

(4, 4)
(4, 28, 7)
(28, 4, 7)
(4, 4)


561e-02,  1.94928692e-02, -1.26585688e-01,
           1.91273736e-01, -7.03868466e-02, -2.25405086e-01,
           2.63915566e-01],
         [-1.41773400e-02,  2.09944143e-02, -4.24141429e-01,
           2.94342037e-01,  7.93449690e-02,  2.07762804e-01,
          -7.60771606e-02],
         [ 2.90353283e-01,  1.83772411e-01,  6.01320093e-02,
          -1.01739119e-01, -6.57470729e-03,  1.23860772e-01,
           1.03438060e-01],
         [ 3.35143656e-01,  2.67740989e-02, -1.20026752e-02,
           8.16575691e-02, -1.26029690e-01,  1.88267178e-01,
          -1.16539392e-01],
         [-4.90547044e-01, -6.19907176e-02,  1.28709806e-01,
           7.14807705e-02, -1.20465830e-01, -6.05553085e-03,
          -7.79432653e-02],
         [ 2.94972680e-02, -1.09934065e-01, -1.69688160e-01,
           9.45911014e-02,  9.09648664e-02,  3.92858226e-02,
          -3.51156287e-01],
         [ 1.35535035e-01,  7.06008383e-02,  3.27907283e-01,
           6.36480079e-02,  1.49172763e-02, -1.02391162e-