# Matrix Product State (MPS) with Compression Layer

This code implements a training procedure for a Matrix Product State (MPS) model with a compression layer, following the algorithm described in the paper [Generative Learning of Continuous Data by Tensor Networks](https://arxiv.org/abs/2310.20498). The MPS models high-dimensional data, and the compression layer maps each site's input features to a lower-dimensional space before they are contracted with the MPS, which improves computational efficiency.

**Key Components:**

- **`contract_mps_except_i` Function**: Contracts the MPS with compressed input data, excluding a specified site.
- **`train_compression_layer` Function**: Trains the compression matrices using the Procrustes problem.
- **Procrustes Problem**: Solved using Singular Value Decomposition (SVD) to update the compression matrices.

In this implementation, we provide detailed explanations and step-by-step computations to facilitate understanding and future modifications.


In [None]:
import torch
import torch.nn as nn

### `contract_mps_except_i` Function

This function computes the vector $v_{i,j}$ for a given sample $x^{(j)}$ and site $i$ by contracting the MPS with the compressed input data at every site except $i$.

**Process:**

1. **Compress Input Data**: For each site $n \neq i$, compress the input data $x_n^{(j)}$ using the compression matrix $U_n$:
   
   $\tilde{x}_n^{(j)} = U_n^T x_n^{(j)}$
   

2. **Contract with MPS Tensors**: Contract each compressed slice with its corresponding MPS tensor over the physical dimension.

3. **Sequential Contraction**: Sequentially contract the MPS tensors over the bond dimensions, leaving the physical dimension at site $i$ uncontracted.

4. **Final Contraction**: After contracting all sites, the result is a vector $v_{i,j}$ of length $d$ (the compressed physical dimension), representing the MPS contracted with the data at every site except $i$.


This vector $v_{i,j}$ is then used in the training process to update the compression matrix $U_i$ via the Procrustes problem.
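
The four steps above can be sketched on a three-site chain with the middle site excluded. This is a minimal illustration rather than the notebook's implementation: the dimensions `D`, `d`, `chi`, the random tensors, and the use of `torch.einsum` are assumptions chosen for brevity.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumed): data dimension D, compressed dimension d, bond dimension chi
D, d, chi = 3, 2, 4
shapes = [(1, d, chi), (chi, d, chi), (chi, d, 1)]
mps = [torch.randn(s) for s in shapes]

# Random isometries U_n (orthonormal columns) from a reduced QR decomposition
U_list = [torch.linalg.qr(torch.randn(D, d))[0] for _ in range(3)]
x = torch.randn(3, D)  # one sample: a D-dimensional feature vector per site

i = 1  # excluded site

# Steps 1-2: compress x_n with U_n^T and contract the physical leg for n != i
M0 = torch.einsum("apb,p->ab", mps[0], U_list[0].T @ x[0])  # (1, chi)
M2 = torch.einsum("apb,p->ab", mps[2], U_list[2].T @ x[2])  # (chi, 1)

# Steps 3-4: contract the bond legs; the physical leg of site i stays open
v = torch.einsum("ab,bpc,cd->p", M0, mps[i], M2)  # (d,)

# Consistency check: closing the open leg with the compressed x_i
# reproduces the full MPS amplitude for this sample
M1 = torch.einsum("apb,p->ab", mps[i], U_list[i].T @ x[i])  # (chi, chi)
full = (M0 @ M1 @ M2).squeeze()
c = (U_list[i].T @ x[i]) @ v
print(v.shape, torch.allclose(c, full, atol=1e-5))
```

The final check shows why $v_{i,j}$ is useful: the full amplitude depends on $U_i$ only through the bilinear form $x_i^T U_i v_{i,j}$.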


In [None]:
def contract_mps_except_i(mps, U_list, input_data, exclude_index):
    """
    Contracts the MPS with compressed input data excluding site i.
    Returns a vector of shape (phys_compressed,) corresponding to v_{i,j}.
    """
    N = len(mps)
    # Step 1: Compress all data slices except at the excluded index
    compressed_slices = []
    for idx in range(N):
        if idx != exclude_index:
            data_slice = input_data[idx]          # Shape: (physical_dim_data,)
            U_i = U_list[idx]                     # Shape: (physical_dim_data, phys_compressed)
            compressed_slice = data_slice @ U_i   # Shape: (phys_compressed,)
            compressed_slices.append(compressed_slice)
        else:
            compressed_slices.append(None)  # Placeholder for the excluded index

    # Step 2: Contract compressed slices with MPS tensors over the physical dimension
    mps_contracted_list = []
    for idx in range(N):
        mps_tensor = mps[idx]  # Shape: (bond_dim_left, phys_compressed, bond_dim_right)

        if idx != exclude_index:
            compressed_slice = compressed_slices[idx]  # Shape: (phys_compressed,)
            # Contract over the physical dimension (dimension 1)
            mps_contracted = torch.tensordot(mps_tensor, compressed_slice, dims=([1], [0]))
            # Resulting shape: (bond_dim_left, bond_dim_right)
        else:
            # Keep the MPS tensor uncontracted over the physical dimension
            mps_contracted = mps_tensor  # Shape: (bond_dim_left, phys_compressed, bond_dim_right)

        mps_contracted_list.append(mps_contracted)

    # Step 3: Sequentially contract neighbouring tensors over their shared bond dimension.
    # Non-excluded sites are matrices (bond_dim_left, bond_dim_right); the excluded site
    # keeps its physical leg, so the running result gains one axis when it is absorbed.
    result = mps_contracted_list[0]
    for idx in range(1, N):
        # Last axis of result is its right bond; first axis of the next tensor is its left bond
        result = torch.tensordot(result, mps_contracted_list[idx], dims=([-1], [0]))

    # The boundary bonds have size 1, so result has shape (1, phys_compressed, 1).
    # reshape(-1) drops them and still returns a 1-D tensor when phys_compressed == 1
    # (squeeze() would return a 0-D tensor in that case).
    v_i_j = result.reshape(-1)
    return v_i_j  # Shape: (phys_compressed,)

### Training the Compression Layer (`train_compression_layer` Function)

The `train_compression_layer` function trains the compression matrices $\{ U_i \}$ using the input data and the MPS.

**Algorithm Overview:**

For each site $i$ and over multiple epochs:

1. **Compute $u_{i,j}$ and $v_{i,j}$**:
   - $u_{i,j} = x_i^{(j)}$: The input feature vector at site $i$ for sample $j$.
   - $v_{i,j}$: The MPS-contracted embedding excluding site $i$, computed using `contract_mps_except_i`.

2. **Compute Inner Product $c_j$**:
   
   $c_j = u_{i,j}^T U_i v_{i,j}$

3. **Calculate Magnitude and Phase**:
   - Magnitude:
     
     $p_j = |c_j| + \delta$

     where $\delta$ is a small constant to prevent division by zero.
   - Phase:
     
     $\phi_j = \frac{c_j}{p_j}$

4. **Accumulate Negative Log-Likelihood Loss**:
   
   $\text{NLL Loss} = -\sum_j \log(p_j)$

5. **Update Compression Matrix $U_i$**:
   - Formulate the matrix $B$:
     
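     $B = \sum_j \left( p_j^\epsilon \phi_j \right)^{-1} u_{i,j} v_{i,j}^T$
     where $\epsilon$ is a stability parameter adjusted during training. For real-valued data $\phi_j = \pm 1$, so $\phi_j^{-1} = \phi_j$ and the weight reduces to $\phi_j / p_j^\epsilon$, which is the form used in the code.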
   - Solve the Procrustes problem using SVD to update $U_i$:
     
     $U_i = U_{\text{B}} V_{\text{B}}^T$
     where $U_{\text{B}}$ and $V_{\text{B}}$ come from the SVD of $B$:
     
     $B = U_{\text{B}} \Sigma_{\text{B}} V_{\text{B}}^T$
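
As a minimal sketch of steps 2 to 5 for a single site, the per-sample loop can be written in vectorized form. The random `u` and `v` below are stand-ins for $u_{i,j}$ and $v_{i,j}$, and the sizes, `delta`, and `eps` values are assumptions for illustration only; the data is assumed real-valued, so the weight is taken as $\phi_j / p_j^\epsilon$.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes and constants (assumed, not taken from the paper)
num_samples, D, d = 50, 3, 2
eps, delta = 0.5, 1e-12

u = torch.randn(num_samples, D, dtype=torch.float64)  # stand-in for u_{i,j}
v = torch.randn(num_samples, d, dtype=torch.float64)  # stand-in for v_{i,j}
U, _ = torch.linalg.qr(torch.randn(D, d, dtype=torch.float64))  # current U_i

# Step 2: c_j = u_{i,j}^T U_i v_{i,j} for all samples at once
c = torch.einsum("jm,mk,jk->j", u, U, v)

# Step 3: magnitude and phase
p = c.abs() + delta
phi = c / p

# Step 4: negative log-likelihood accumulated over samples
nll = -torch.log(p).sum()

# Step 5: B = sum_j w_j u_j v_j^T with w_j = phi_j / p_j^eps (real-valued case)
w = phi / p**eps
B = torch.einsum("j,jm,jk->mk", w, u, v)

# Procrustes update: torch.linalg.svd returns Vh = V^T, so U_B V_B^T is U_B @ Vh
U_B, S, Vh = torch.linalg.svd(B, full_matrices=False)
U_new = U_B @ Vh

print(U_new.shape, nll.item())
```

With `full_matrices=False`, `U_B` has shape `(D, d)` and `Vh` has shape `(d, d)`, so `U_new` has the same shape as the compression matrix it replaces.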
     

**Key Points:**

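- The compression matrices $U_i$ are updated to minimize the negative log-likelihood loss.
- The use of SVD ensures that the updated $U_i$ is an isometry: its columns are orthonormal, so $U_i^T U_i = I$ (and $U_i$ is orthogonal when the data dimension equals the compressed dimension).
- The stability parameter $\epsilon$ starts at 0.5 and is increased by 0.05 after each site update, up to a maximum of 1.0; it is decreased again by 0.05 if an update produces non-finite values.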
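
The role of the SVD can be checked numerically. The sketch below uses a random matrix in place of $B$ (an assumption for illustration) and compares the Procrustes solution $U_{\text{B}} V_{\text{B}}^T$ against randomly drawn isometries: among all matrices with orthonormal columns, it attains the largest value of $\mathrm{tr}(U^T B)$, which equals the sum of the singular values of $B$.

```python
import torch

torch.manual_seed(0)

D, d = 5, 3  # illustrative sizes (assumed)
B = torch.randn(D, d, dtype=torch.float64)  # stand-in for the accumulated matrix B

U_B, S, Vh = torch.linalg.svd(B, full_matrices=False)
U_star = U_B @ Vh  # Procrustes solution

def overlap(U):
    """Objective tr(U^T B) maximized by the Procrustes solution."""
    return torch.trace(U.T @ B)

best = overlap(U_star)

# Compare against 200 random isometries drawn via reduced QR
random_overlaps = torch.stack([
    overlap(torch.linalg.qr(torch.randn(D, d, dtype=torch.float64))[0])
    for _ in range(200)
])

print("optimum equals sum of singular values:", torch.allclose(best, S.sum()))
print("no random isometry does better:", bool((random_overlaps <= best + 1e-9).all()))
```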


In [None]:
def train_compression_layer(input_data, mps, U_list, num_tensors, bond_dim, phys_compressed, epochs=5):
    """
    Train the compression layer for an MPS using a set of input data over multiple epochs.
    Implements the algorithm as per the provided pseudocode from the paper.
    """
    num_samples = input_data.shape[0]
    physical_dim_data = input_data.shape[2]
    epsilon = 0.5  # Controls stability, will be adjusted

    for epoch in range(epochs):
        total_nll_loss = 0
        print(f"Epoch {epoch+1}/{epochs}")
        for tensor_index in range(num_tensors):  # Loop over each site
            u_list = []  # Original feature embeddings (u_{i,j})
            v_list = []  # MPS-contracted embeddings excluding site i (v_{i,j})
            c_list = []  # Inner products c_j
            p_list = []  # Magnitudes |c_j|
            phi_list = []  # Phases φ_j

            for sample_idx in range(num_samples):  # Loop over each sample in the batch
                x_j = input_data[sample_idx]       # Shape (N, physical_dim_data)
                u_i_j = x_j[tensor_index]          # Shape (physical_dim_data,)

                # Compute v_{i,j}
                v_i_j = contract_mps_except_i(mps, U_list, x_j, tensor_index)  # Shape (phys_compressed,)

                # Compute c_j = u_i_j^T U_i v_i_j
                U_i = U_list[tensor_index]  # Shape (physical_dim_data, phys_compressed)
                temp = U_i.T @ u_i_j        # Shape (phys_compressed,)

                c_j = temp @ v_i_j          # Scalar

                p_j = torch.abs(c_j) + 1e-20  # Magnitude
                phi_j = c_j / p_j            # Phase

                # Store intermediate results
                u_list.append(u_i_j)
                v_list.append(v_i_j)
                c_list.append(c_j)
                p_list.append(p_j)
                phi_list.append(phi_j)

                # Accumulate NLL loss
                nll_loss = -torch.log(p_j)
                total_nll_loss += nll_loss

            # Update U_i using Procrustes problem
            # Compute B = Σ_j (p_j^ε φ_j)^{-1} u_{i,j} v_{i,j}^T
            # For real data φ_j = ±1, so φ_j^{-1} = φ_j and the weight is φ_j / p_j^ε
            B = torch.zeros((physical_dim_data, phys_compressed), dtype=torch.float32)
            for j in range(num_samples):
                weight = phi_list[j] / (p_list[j] ** epsilon)
                B += weight.real * torch.outer(u_list[j], v_list[j])

            # Perform SVD on B
            try:
                B_U, _, B_Vh = torch.linalg.svd(B, full_matrices=False)
                # torch.linalg.svd returns Vh = V^T, so U_B V_B^T is B_U @ B_Vh
                U_i_new = B_U @ B_Vh
                U_list[tensor_index] = nn.Parameter(U_i_new)
            except Exception as e:
                print(f"SVD failed at tensor index {tensor_index}: {e}")
                continue  # Skip updating this U_i if SVD fails

            # Adjust epsilon
            epsilon = min(1.0, epsilon + 0.05)
            if not torch.isfinite(U_list[tensor_index]).all():
                epsilon = max(0.0, epsilon - 0.05)
        print(f"Total NLL Loss after epoch {epoch+1}: {total_nll_loss.item()}")
    return U_list

## Implementation

The cell below sets up a small example using the functions defined in the previous sections: it initializes a random MPS and the compression matrices, generates toy input data, and runs `train_compression_layer`.

**Note**: The code uses PyTorch for tensor operations and assumes that you have PyTorch installed in your environment.


In [None]:
# Parameters
phys_compressed = 2  # Compressed dimension d
bond_dim = 4         # Bond dimension χ
N = 10               # Number of tensors/sites (reduced for simplicity)
physical_dim_data = 2  # Original physical dimension D
num_samples = 100    # Number of data samples
epochs = 15         # Number of epochs for training

# Initialize MPS
mps = []
for i in range(N):
    if i == 0:
        mps_tensor = torch.randn(1, phys_compressed, bond_dim) / N  # Shape (1, d, χ)
    elif i == N - 1:
        mps_tensor = torch.randn(bond_dim, phys_compressed, 1) / N  # Shape (χ, d, 1)
    else:
        mps_tensor = torch.randn(bond_dim, phys_compressed, bond_dim) / N  # Shape (χ, d, χ)
    mps.append(mps_tensor)

# Initialize U_list with the Q factor of a QR decomposition of a random matrix,
# so that each U_i starts with orthonormal columns
U_list = []
for _ in range(N):
    random_matrix = torch.randn(physical_dim_data, phys_compressed)
    q, _ = torch.linalg.qr(random_matrix)  # torch.qr is deprecated in favor of torch.linalg.qr
    U_list.append(nn.Parameter(q))
    
# Generate toy input data: every site of every sample holds the same
# first standard basis vector, e.g. [1, 0] when physical_dim_data == 2
input_data = torch.zeros(num_samples, N, physical_dim_data)
input_data[:, :, 0] = 1.0

# Train the compression layer over multiple epochs
U_list = train_compression_layer(input_data, mps, U_list, N, bond_dim, phys_compressed, epochs=epochs)

print("Training completed.")
print("U_list[0]:")
print(U_list[0].data)