In [None]:
import sys
sys.path.insert(0, "/Users/maxge/Documents/Studium/München/02_SS 2024/QEL/Block encoding generalization/img-compression-mps/ND MPS Encoding")

In [6]:
from mps_ND import NDMPS
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt 
from utils_ND import *
import quimb.tensor as qtn
from scipy.fftpack import dct, idct
import time



In [7]:
def get_factorlist(shape):
    # todo rename to actually get_block_sizes
    factor_lists = []
    for dim_size in shape:
        if dim_size == 1:
            # Handle dimension of size 1 by assigning a factor of 1
            factors = [1]
        else:
            factors_dict = factorint(dim_size)
            # Extract prime factors and repeat them according to their exponents
            factors = []
            for prime, exponent in sorted(factors_dict.items()):
                factors.extend([prime] * exponent)
        # Sort the factors in ascending order
        factors_sorted = sorted(factors)
        factor_lists.append(factors_sorted)
    
    # Step 2: Balancing the number of factors across all dimensions
    # Determine the minimum number of factors among all dimensions
    min_factors = min(len(factors) for factors in factor_lists)
    
    # Balance factors by grouping the smallest factors in dimensions with more factors
    for idx, factors in enumerate(factor_lists):
        if len(factors) > min_factors:
            factor_lists[idx] = balance_factors(factors, min_factors)
        elif len(factors) < min_factors:
            # If a dimension has fewer factors, pad with 1s to reach min_factors
            # This effectively treats missing factors as trivial
            factor_lists[idx].extend([1] * (min_factors - len(factors)))
            factor_lists[idx] = sorted(factor_lists[idx])
        # If equal, do nothing
    for idx, list in enumerate(factor_lists):
        if idx%2 == 1:
            factor_lists[idx] = factor_lists[idx][::-1] 

    factor_lists = np.array(factor_lists).T
    prod_block_sizes = np.ones((len(factor_lists)+1, len(factor_lists[0])), dtype = int)
    prod_block_sizes[1:-1] = np.cumprod(factor_lists[-1:0:-1], axis =0)[::-1]
    prod_block_sizes[0] = prod_block_sizes[0] * 1e100

    return factor_lists, prod_block_sizes

@time_function
def hierarchical_block_indexing(index, prod_block_sizes):
    return np.floor(np.mod(index.reshape([1]+list(index.shape)), prod_block_sizes[:-1].reshape(list(prod_block_sizes[:-1].shape)+[1]*(prod_block_sizes.shape[1])))/prod_block_sizes[1:].reshape(list(prod_block_sizes[1:].shape)+[1]*(prod_block_sizes.shape[1]))).astype(int)

def gen_encoding_map(shape):
    dim = len(shape)
    block_sizes, prod_blocks = get_factorlist(shape)
    indices_all = np.indices(shape)
    mapped_indexes = hierarchical_block_indexing(indices_all, prod_blocks)
    final_map = np.empty([len(block_sizes)]+list(shape))
    for i in range(len(block_sizes)):
        final_map[i] = np.ravel_multi_index(mapped_indexes[i], block_sizes[i])
    return np.prod(block_sizes, axis= 1), final_map.astype(int)

In [8]:
class NDMPS_new:
    def __init__(self, mps=None, qubit_size=None, encoding_map=None, norm=True, mode="Std", min_value = 0, max_value = 1):
        self.qubit_size = qubit_size
        self.encoding_map = encoding_map
        self.mps = mps
        self.norm = norm #Normalize matrix data
        #Compression mode 
        # "Std" standard Block Encoding
        # "DCT" discrete cosine fourier transform before compression
        self.mode = mode 
        self.min_value = min_value
        self.max_value = max_value
    
    @classmethod
    # @time_function
    def from_matrix(cls, matrix, norm = False, mode = "Std"):
        qubit_size, encoding_map = gen_encoding_map(matrix.shape)
        encoding_map = np.moveaxis(encoding_map, 0, -1)

        #check for flags
        if norm:
            matrix = matrix / (np.linalg.norm(matrix))
        if mode == "DCT":
            matrix = dct(matrix, norm = "ortho")

        #initialize tensor
        contracted_tensor = np.empty(shape = tuple(qubit_size))


        #encode matrix data
        # start_nested_loop = time.time()
        it = np.nditer(matrix, flags=['multi_index'])
        for _ in it:
            contracted_tensor[tuple(encoding_map[it.multi_index])] = matrix[it.multi_index]
        
        # nested_loop_time = time.time() - start_nested_loop
        # print(f"Time for nested loops: {nested_loop_time:.4f} seconds")
        #put in MPS
        # start_mps_creation = time.time()
        mps = qtn.MatrixProductState.from_dense(contracted_tensor, dims = tuple(qubit_size))
        # mps_creation_time = time.time() - start_mps_creation
        # print(f"Time to create MPS from dense tensor: {mps_creation_time:.4f} seconds")
        #return class
        return cls(mps, qubit_size, encoding_map, norm, mode, np.min(matrix), np.max(matrix))

    # @time_function
    def compression_ratio(self):
        initial_N = np.prod(self.qubit_size)
        compressed_N = self.number_elements_in_MPS()
        # TODO: also implement the compression rate in bits / bits
        return compressed_N / initial_N
        
    # @time_function
    def compress(self, cutoff):
        """
        Compresses a Matrix Product State (MPS) by cutting bonds with a relative cutoff value.
        Arguments:
            cutoff (float): The relative cutoff value to use for bond compression.
        Returns:
            None
        """
        size = len(self.mps.sites)
        for i in np.arange(1, size):
            t1 = self.mps[i-1] # Tensor 1
            t2 = self.mps[i] # Tensor 2
            # Compress bond according to percentage * bond dimension
            qtn.tensor_compress_bond(t1, t2, cutoff = cutoff, cutoff_mode = "rel") 
    def continuous_compress(self, cutoff, print_ratio = True):
        compress_list = np.array([0.01, 0.05, 0.1, 0.2, 0.5, 0.8, 1]) * cutoff
        for c in compress_list:
            self.compress(c)
            if print_ratio:
                print(f"Compression ratio at {c}: {self.compression_ratio()}")


    # @time_function
    def number_elements_in_MPS(self):
        """
        Returns the number of tensor elements in the quimb MPS.
        Parameters:
            mps: quimb MatrixProductState object
        Returns:
            int: The total number of tensor elements in the MPS."""
        return sum(t.size for t in self.mps)
    
    # @time_function
    def mps_to_matrix(self):
        """
        Converts the compressed Matrix Product State (MPS) representation back to an image matrix.
        Arguments:
            None
        Returns:
            Compressed matrix
        """

        #conract mps
        contracted_mps = self.mps ^ ...

        #order tensor legs back
        for i in np.arange(len(contracted_mps.inds)):
            contracted_mps.moveindex("k"+str(i), i, inplace=True)
        
        #return in correct format

        recovered_tensor = np.empty(self.encoding_map.shape)
        it = np.nditer(recovered_tensor, flags=['multi_index'])
        for _ in it:
            recovered_tensor[it.multi_index] = contracted_mps.data[self.encoding_map[it.multi_index]]
        
        if self.mode == "Std":
            return recovered_tensor
        elif self.mode == "DCT":
            return idct(recovered_tensor, norm = "ortho")

In [9]:
test_tens = np.random.rand(240,240,220)

In [11]:
mps_test = NDMPS_new.from_matrix(test_tens)

  prod_block_sizes[0] = prod_block_sizes[0] * 1e100


Time to run hierarchical_block_indexing: 1.4560 seconds


In [12]:
mps_old = NDMPS.from_matrix(test_tens)

  prod_block_sizes[0] = prod_block_sizes[0] * 1e100


Time for nested loops: 1739022690.9270 seconds
Time to create MPS from dense tensor: 2.4463 seconds


In [33]:
mps_old.mps @ mps_test.mps

4223506.728076902

In [34]:
mps_test.mps @ mps_test.mps

4223506.728076903

In [19]:
qubit_size, encoding_map = gen_encoding_map(test_tens.shape)
encoding_map = np.moveaxis(encoding_map, 0, -1)

contracted_tensor = np.empty(shape = tuple(qubit_size))
print(contracted_tensor.shape)

  prod_block_sizes[0] = prod_block_sizes[0] * 1e100


Time to run hierarchical_block_indexing: 1.3211 seconds
(30, 32, 80, 165)


In [20]:
encoding_map.shape

(240, 240, 220, 4)

In [35]:
shape_before = test_tens.shape
k = encoding_map.shape[-1]
flat_data = test_tens.flatten()
flat_new_indices = encoding_map.reshape(-1, k).astype(int)
new_shape = [flat_new_indices[:, dim].max() + 1 for dim in range(k)]
new_tensor = np.zeros(new_shape, dtype=test_tens.dtype)
indices = tuple(flat_new_indices[:, dim] for dim in range(k))
new_tensor[indices] = flat_data

In [39]:
mps_try = qtn.MatrixProductState.from_dense(new_tensor, dims = tuple(qubit_size))

In [40]:
mps_test.mps @ mps_try

4223506.728076902

In [None]:
4223506.728076902

[np.int64(30), np.int64(32), np.int64(80), np.int64(165)]

In [92]:
class NDMPS_speed:
    def __init__(self, mps=None, qubit_size=None, encoding_map=None, norm=True, mode="Std", min_value = 0, max_value = 1):
        self.qubit_size = qubit_size
        self.encoding_map = encoding_map
        self.mps = mps
        self.norm = norm #Normalize matrix data
        #Compression mode 
        # "Std" standard Block Encoding
        # "DCT" discrete cosine fourier transform before compression
        self.mode = mode 
        self.min_value = min_value
        self.max_value = max_value
    
    @classmethod
    # @time_function
    def from_matrix(cls, matrix, norm = False, mode = "Std"):
        qubit_size, encoding_map = gen_encoding_map(matrix.shape)
        encoding_map = np.moveaxis(encoding_map, 0, -1)

        #check for flags
        if norm:
            matrix = matrix / (np.linalg.norm(matrix))
        if mode == "DCT":
            matrix = dct(matrix, norm = "ortho")

        #initialize tensor
        contracted_tensor = np.empty(shape = tuple(qubit_size), dtype=matrix.dtype)

        # rearange the data
        k = encoding_map.shape[-1]
        flat_data = matrix.flatten()
        flat_new_indices = encoding_map.reshape(-1, k).astype(int)
        new_shape = [flat_new_indices[:, dim].max() + 1 for dim in range(k)]
        indices = tuple(flat_new_indices[:, dim] for dim in range(k))
        contracted_tensor[indices] = flat_data
        
        nested_loop_time = time.time()
        print(f"Time for nested loops: {nested_loop_time:.4f} seconds")
        #put in MPS
        start_mps_creation = time.time()
        mps = qtn.MatrixProductState.from_dense(contracted_tensor, dims = tuple(qubit_size))
        mps_creation_time = time.time() - start_mps_creation
        print(f"Time to create MPS from dense tensor: {mps_creation_time:.4f} seconds")
        #return class
        return cls(mps, qubit_size, encoding_map, norm, mode, np.min(matrix), np.max(matrix))

    # @time_function
    def compression_ratio(self):
        initial_N = np.prod(self.qubit_size)
        compressed_N = self.number_elements_in_MPS()
        # TODO: also implement the compression rate in bits / bits
        return compressed_N / initial_N
        
    # @time_function
    def compress(self, cutoff):
        """
        Compresses a Matrix Product State (MPS) by cutting bonds with a relative cutoff value.
        Arguments:
            cutoff (float): The relative cutoff value to use for bond compression.
        Returns:
            None
        """
        size = len(self.mps.sites)
        for i in np.arange(1, size):
            t1 = self.mps[i-1] # Tensor 1
            t2 = self.mps[i] # Tensor 2
            # Compress bond according to percentage * bond dimension
            qtn.tensor_compress_bond(t1, t2, cutoff = cutoff, cutoff_mode = "rel") 
    def continuous_compress(self, cutoff, print_ratio = True):
        compress_list = np.array([0.01, 0.05, 0.1, 0.2, 0.5, 0.8, 1]) * cutoff
        for c in compress_list:
            self.compress(c)
            if print_ratio:
                print(f"Compression ratio at {c}: {self.compression_ratio()}")


    # @time_function
    def number_elements_in_MPS(self):
        """
        Returns the number of tensor elements in the quimb MPS.
        Parameters:
            mps: quimb MatrixProductState object
        Returns:
            int: The total number of tensor elements in the MPS."""
        return sum(t.size for t in self.mps)
    
    # @time_function
    def mps_to_matrix(self):
        """
        Converts the compressed Matrix Product State (MPS) representation back to an image matrix.
        Arguments:
            None
        Returns:
            Compressed matrix
        """

        #conract mps
        contracted_mps = self.mps ^ ...

        #order tensor legs back
        for i in np.arange(len(contracted_mps.inds)):
            contracted_mps.moveindex("k"+str(i), i, inplace=True)
        
        #return in correct format
        k = self.encoding_map.shape[-1]
        
        recovered_tensor = np.empty(self.encoding_map.shape)
        contracted_mps = contracted_mps.data
        recovered_tensor = contracted_mps[tuple(self.encoding_map[..., dim] for dim in range(k))]
        
        if self.mode == "Std":
            return recovered_tensor
        elif self.mode == "DCT":
            return idct(recovered_tensor, norm = "ortho")

In [87]:
mps_speed = NDMPS_speed.from_matrix(test_tens)

  prod_block_sizes[0] = prod_block_sizes[0] * 1e100


Time to run hierarchical_block_indexing: 1.2963 seconds
Time for nested loops: 1737928377.1835 seconds
Time to create MPS from dense tensor: 2.2617 seconds


In [88]:
mps_speed.mps @ mps_test.mps

4223506.728076902

In [89]:
con_tens = mps_speed.mps ^ ...

In [90]:
type(con_tens.data)

numpy.ndarray

In [93]:
mps_speed.mps_to_matrix()

array([[[0.43384858, 0.12321303, 0.92785341, ..., 0.90405203,
         0.54961228, 0.07769605],
        [0.56101817, 0.09163379, 0.05478229, ..., 0.1001284 ,
         0.93329716, 0.18804416],
        [0.17072938, 0.46400181, 0.76266317, ..., 0.35617674,
         0.72476931, 0.93450828],
        ...,
        [0.6809151 , 0.00982142, 0.91035291, ..., 0.76871745,
         0.76100618, 0.64403658],
        [0.88386582, 0.57398938, 0.19550277, ..., 0.74408306,
         0.40328463, 0.95067469],
        [0.01815623, 0.07999099, 0.65302307, ..., 0.75318321,
         0.29115235, 0.11161086]],

       [[0.5642255 , 0.25830904, 0.56575468, ..., 0.30530762,
         0.97005248, 0.27493029],
        [0.38114399, 0.02276837, 0.5884885 , ..., 0.17433811,
         0.01266093, 0.09198046],
        [0.57487109, 0.01311029, 0.46535928, ..., 0.04451448,
         0.91478219, 0.33848192],
        ...,
        [0.38537621, 0.4245431 , 0.39818488, ..., 0.58143511,
         0.43378469, 0.28558417],
        [0.0

In [96]:
np.max(test_tens - mps_speed.mps_to_matrix())

np.float64(9.409584222908052e-12)

In [49]:
import numpy as np

# Example dimensions
A, B, C = 2, 3, 4
k = 2  # Number of new dimensions

# Original data tensor (for demonstration)
data = np.arange(A * B * C).reshape(A, B, C)
print("Original Data:\n", data)

# Encoding map tensor with new indices
encoding_map = np.stack((data % 3, data // 3), axis=-1)  # Shape: (A, B, C, 2)
print("Encoding Map:\n", encoding_map)

# Create new_tensor based on encoding_map (simulating the forward mapping)
new_shape = [encoding_map[..., dim].max() + 1 for dim in range(k)]
new_tensor = np.zeros(new_shape, dtype=data.dtype)
flat_data = data.flatten()
flat_new_indices = encoding_map.reshape(-1, k)
indices = tuple(flat_new_indices[:, dim] for dim in range(k))
new_tensor[indices] = flat_data
print("New Tensor:\n", new_tensor)

# Reverse mapping: reconstruct the original data
reconstructed_data = new_tensor[tuple(encoding_map[..., dim] for dim in range(k))]
print("Reconstructed Data:\n", reconstructed_data)

# Verify correctness
assert np.array_equal(data, reconstructed_data), "Reconstruction failed!"
print("Reconstruction successful!")

Original Data:
 [[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
Encoding Map:
 [[[[0 0]
   [1 0]
   [2 0]
   [0 1]]

  [[1 1]
   [2 1]
   [0 2]
   [1 2]]

  [[2 2]
   [0 3]
   [1 3]
   [2 3]]]


 [[[0 4]
   [1 4]
   [2 4]
   [0 5]]

  [[1 5]
   [2 5]
   [0 6]
   [1 6]]

  [[2 6]
   [0 7]
   [1 7]
   [2 7]]]]
New Tensor:
 [[ 0  3  6  9 12 15 18 21]
 [ 1  4  7 10 13 16 19 22]
 [ 2  5  8 11 14 17 20 23]]
Reconstructed Data:
 [[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
Reconstruction successful!


In [54]:
reconstructed_data = new_tensor[tuple(encoding_map[..., dim] for dim in range(k))]
print("Reconstructed Data:\n", reconstructed_data)

Reconstructed Data:
 [[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]


In [55]:
reconstructed_data.shape

(2, 3, 4)

In [56]:
encoding_map.shape

(2, 3, 4, 2)