In [None]:
import itertools
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm

In [None]:
# Compute device for the whole notebook. The CPU is selected on purpose for testing;
# to use a GPU when one is present, replace the assignment below with:
#     torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
print(device)

In [None]:
class Projective_Norm():
    '''
    Class to perform the optimization for calculating the minimum nuclear (projective) norm decomposition of a tensor

    The target T is approximated by T' = sum_r K_r * phi_r / ||phi_r||, where every phi_r is a
    product (rank-one) tensor built from one factor vector per mode. For form == "density_matrix"
    each term is instead rho_r = |x_r><y_r| / ||.||_F, with |x_r> and |y_r> product kets.
    Norm_Cost (the sum of coefficient magnitudes) is minimised while T' is kept close to T;
    its minimum value is the nuclear norm of T.

    Parameters:
    -----------
    nature: str : nature of the tensor ("symmetric" or "unsymmetric"); only used when form == "tensor"
    dim: tuple : dimensions of the tensor
    device: torch.device : device to perform the optimization
    type: str : type of the tensor ("real" or "complex")
    form: str : form of the input tensor ("tensor" or "density_matrix")

    Methods:
    --------
    is_symmetric_tensor: Check if the tensor is symmetric
    Parameter_Init: Initialize the parameters for the optimization
    Reconstructed_State: Reconstruct the tensor from the parameters to get T' using einsum
    Norm_Cost: Sum of the coefficient magnitudes. Its minimum value is the nuclear norm of the tensor
    Reconstruction_Cost: Mean squared error between the original tensor and the reconstructed tensor
    Coefficent_Loss: Square of the number of coefficients whose magnitude exceeds 1e-3
    Nuclear_Rank: Number of coefficients whose magnitude exceeds a threshold
    forward: Calculate the total (weighted) loss of the optimization
    '''

    def __init__(self, nature, dim, device, type, form):
        '''
        Validate the configuration and derive the sizes used by the optimization

        Parameters:
        -----------
        nature: str : nature of the tensor ("symmetric" or "unsymmetric")
        dim: tuple : dimensions of the tensor
        device: torch.device : device to perform the optimization
        type: str : type of the tensor ("real" or "complex")
        form: str : form of the input tensor ("tensor" or "density_matrix")

        Initializes:
        ------------
        order: int : order (number of modes) of the tensor
        dim_total: int : total dimensionality of the tensor (product of dim)
        r_max: int : number of rank-one terms used in the decomposition

        Raises:
        -------
        ValueError: if nature, type or form is not one of the accepted strings
        '''
        self.nature = nature
        if self.nature not in ["symmetric", "unsymmetric"]:
            raise ValueError("Invalid nature string. Choose 'symmetric' or 'unsymmetric'")

        self.dim = dim
        self.order = np.size(dim)
        self.device = device
        self.type = type
        if self.type not in ["real", "complex"]:
            raise ValueError("Invalid type string. Choose 'real' or 'complex'")

        self.form = form
        if self.form not in ["tensor", "density_matrix"]:
            raise ValueError("Invalid form string. Choose 'tensor' or 'density_matrix'")

        self.dim_total = np.prod(self.dim)

        # Number of terms: dim_total / max(dim) for a tensor, its square for a density matrix
        # (one ket factor and one bra factor per term). max(dim) divides dim_total exactly.
        base_rank = int(self.dim_total // np.max(self.dim))
        self.r_max = base_rank if self.form == "tensor" else base_rank**2


    def is_symmetric_tensor(self, T):
        '''
        Check if the tensor is symmetric, i.e. all modes have the same size and T is invariant
        under every permutation of its indices (checked with torch.allclose)

        Parameters:
        -----------
        T: torch.tensor : tensor to check for symmetry

        Returns:
        --------
        None

        Raises:
        -------
        ValueError: if the dimensions differ or T changes under some index permutation
        '''
        if len(set(self.dim)) > 1:
            raise ValueError("Input tensor is not symmetric as the dimensions are not equal")

        # NOTE: this visits order! permutations, which is fine for the small orders used here.
        for perm in itertools.permutations(range(self.order)):
            if not torch.allclose(T, T.permute(perm)):
                raise ValueError("Input tensor is not symmetric as it is not invariant under permutation of indices")


    def Parameter_Init(self):
        '''
        Initialize the parameters for the optimization with standard-normal values

        Keys of the returned dictionary:
          "K"             : coefficients, shape (r_max,) or (r_max, 2) for complex tensors
          "x1".."xN"      : factor matrices, shape (dim[i], r_max[, 2]) (unsymmetric tensor, and
                            the kets of a density matrix)
          "x"             : single factor matrix shared by every mode (symmetric tensor)
          "y1".."yN"      : bra factor matrices (density matrix only)
        A trailing axis of size 2 stores (real, imaginary) parts for complex quantities.

        Parameters:
        -----------
        None

        Returns:
        --------
        Param_dict: dict : dictionary of leaf tensors (requires_grad=True) for the optimization
        '''
        complex_axis = (2,) if self.type == "complex" else ()

        def leaf(*shape):
            # Every parameter is an independent standard-normal leaf on the target device.
            return torch.randn(shape, requires_grad=True, device=self.device)

        Param_dict = {}
        if self.form == "tensor":
            Param_dict["K"] = leaf(self.r_max, *complex_axis)
            if self.nature == "unsymmetric":
                for i in range(self.order):
                    Param_dict["x"+str(i+1)] = leaf(self.dim[i], self.r_max, *complex_axis)
            else:
                Param_dict["x"] = leaf(self.dim[0], self.r_max, *complex_axis)
        else:
            # Density matrix: coefficients stay real even when kets/bras are complex.
            Param_dict["K"] = leaf(self.r_max)
            for i in range(self.order):
                Param_dict["x"+str(i+1)] = leaf(self.dim[i], self.r_max, *complex_axis)  # Ket factor
                Param_dict["y"+str(i+1)] = leaf(self.dim[i], self.r_max, *complex_axis)  # Bra factor

        return Param_dict


    def _factor_keys(self, prefix):
        '''Names of the per-mode factor matrices: the shared key for a symmetric tensor, otherwise prefix1..prefixN.'''
        if self.form == "tensor" and self.nature == "symmetric":
            return [prefix] * self.order
        return [prefix + str(i+1) for i in range(self.order)]


    def _factor_column(self, P, r):
        '''Column r of a factor matrix, combined into a complex vector when type == "complex".'''
        if self.type == "complex":
            return P[:, r, 0] + 1j*P[:, r, 1]
        return P[:, r]


    def _coefficient(self, K, r):
        '''Coefficient of term r; complex only for a complex tensor (density-matrix coefficients are real).'''
        if self.type == "complex" and self.form == "tensor":
            return K[r, 0] + 1j*K[r, 1]
        return K[r]


    def _magnitudes(self, K):
        '''Magnitude of every coefficient, shape (r_max,).'''
        if self.type == "complex" and self.form == "tensor":
            # |re + i*im| per row; vector_norm defines a zero gradient at the origin.
            return torch.linalg.vector_norm(K, dim=1)
        return torch.abs(K)


    @staticmethod
    def _product_state(vectors):
        '''Outer (tensor) product v_1 (x) v_2 (x) ... of the given 1-D tensors.'''
        phi = vectors[0]
        for v in vectors[1:]:
            phi = torch.einsum('...i,j->...ij', phi, v)
        return phi


    def Reconstructed_State(self, Param_dict):
        '''
        Reconstruct the tensor from the parameters to get T' using einsum. Depending on the form of the
        input tensor, the reconstructed tensor is returned as either a tensor or a density matrix.
        Every rank-one term is normalised before being weighted by its coefficient.

        Parameters:
        -----------
        Param_dict: dict : dictionary of the parameters for the optimization

        Returns:
        --------
        T_recon: torch.tensor : reconstructed tensor T' of shape dim (form == "tensor") or
                                (dim_total, dim_total) (form == "density_matrix");
                                float32 for real type, complex64 for complex type
        '''
        dtype = torch.complex64 if self.type == "complex" else torch.float32
        K = Param_dict["K"]

        if self.form == "tensor":
            keys = self._factor_keys("x")
            T_recon = torch.zeros(self.dim, dtype=dtype, device=self.device)
            for r in range(self.r_max):
                phi = self._product_state([self._factor_column(Param_dict[k], r) for k in keys])
                T_recon = T_recon + self._coefficient(K, r)*(phi/torch.norm(phi))
            return T_recon

        n = int(self.dim_total)
        x_keys = self._factor_keys("x")
        y_keys = self._factor_keys("y")
        T_recon = torch.zeros((n, n), dtype=dtype, device=self.device)
        for r in range(self.r_max):
            # reshape (not view) so flattening never depends on the memory layout einsum returns
            ket = self._product_state([self._factor_column(Param_dict[k], r) for k in x_keys]).reshape(-1, 1)
            bra = self._product_state([self._factor_column(Param_dict[k], r) for k in y_keys]).reshape(-1, 1)
            # conj() is a no-op on real tensors, so one formula covers both types
            rho = ket @ bra.conj().T
            T_recon = T_recon + K[r]*(rho/torch.norm(rho, p='fro'))
        return T_recon


    def Norm_Cost(self, Param_dict):
        '''
        Calculate the cost of the nuclear norm of the tensor: the sum of the coefficient magnitudes.
        Minimum value of this cost (subject to accurate reconstruction) is the nuclear norm of the tensor

        Parameters:
        -----------
        Param_dict: dict : dictionary of the parameters for the optimization

        Returns:
        --------
        norm: torch.tensor : scalar sum of |K_r| over all terms
        '''
        return self._magnitudes(Param_dict["K"]).sum()


    def Reconstruction_Cost(self, T, T_recon):
        '''
        Calculate the cost of the reconstruction of the tensor: the mean squared error between the
        original and the reconstructed tensor. For complex tensors the real and imaginary parts are
        averaged separately and summed.

        Parameters:
        -----------
        T: torch.tensor : original tensor T
        T_recon: torch.tensor : reconstructed tensor T'

        Returns:
        --------
        loss: torch.tensor : mean of the squared entries of T - T'
        '''
        if self.type == "real":
            return F.mse_loss(T, T_recon)
        elif self.type == "complex":
            return F.mse_loss(T.real, T_recon.real) + F.mse_loss(T.imag, T_recon.imag)


    def Coefficent_Loss(self, Param_dict):
        '''
        Calculate the sparsity cost of the coefficients: the square of the number of coefficients
        whose magnitude exceeds 1e-3

        Parameters:
        -----------
        Param_dict: dict : dictionary of the parameters for the optimization

        Returns:
        --------
        torch.tensor : integer scalar, (number of |K_r| > 1e-3) ** 2
        '''
        # NOTE(review): the count is a step function of K, so this term has no gradient; it only
        # shifts the loss value and does not itself push the optimiser towards sparsity.
        return torch.sum(self._magnitudes(Param_dict["K"]) > 1e-3)**2


    def Nuclear_Rank(self, Param_dict, threshold=1e-1):
        '''
        Calculate the nuclear rank of the tensor: the number of coefficients whose magnitude exceeds threshold

        Parameters:
        -----------
        Param_dict: dict : dictionary of the parameters for the optimization
        threshold: float : threshold to consider a coefficient as non-zero (default = 1e-1)

        Returns:
        --------
        non_zero_weights: int : number of non-zero coefficients
        '''
        magnitudes = self._magnitudes(Param_dict["K"].detach())
        return int(torch.sum(magnitudes > threshold))


    def forward(self, T, params):
        '''
        Calculate the total loss of the optimization:
        k_1 * Reconstruction_Cost + k_2 * Coefficent_Loss + k_3 * Norm_Cost

        Parameters:
        -----------
        T: torch.tensor : original tensor T
        params: dict : dictionary of the parameters for the optimization

        Returns:
        --------
        loss: torch.tensor : total loss of the optimization
        '''
        # Weights scale with the total dimension. The reconstruction weight is by far the largest,
        # so matching T takes priority over reducing the norm.
        k_1 = 1e12 * self.dim_total # Reconstruction Cost Coefficent
        k_2 = 1e3 * self.dim_total # Coefficent Cost Coefficent
        k_3 = 1e6* self.dim_total # Norm Cost Coefficent

        T_recon = self.Reconstructed_State(params)
        coeff_loss = self.Coefficent_Loss(params)
        recon_loss = self.Reconstruction_Cost(T, T_recon)
        norm_loss = self.Norm_Cost(params)

        loss = k_1*recon_loss + k_2*coeff_loss + k_3*norm_loss

        return loss

In [None]:
def InputState(dim, type, form, device):
    '''
    Build the target state for the optimization

    The amplitudes are written into a ket of shape dim first; for form == "density_matrix" the
    ket is then turned into the pure-state density matrix |psi><psi|.

    Parameters:
    -----------
    dim: tuple : dimensions of the state tensor
    type: str : "real" (float32) or "complex" (complex64)
    form: str : "tensor" or "density_matrix"
    device: torch.device : device the returned tensor is moved to

    Returns:
    --------
    A: torch.tensor : tensor of shape dim, or (prod(dim), prod(dim)) for a density matrix
    '''
    d_type = torch.complex64 if type == "complex" else torch.float32
    A = torch.zeros(dim, dtype=d_type)

    #### Customize the input state here ####
    # Default: (|00> + |01>)/sqrt(2); assumes dim has at least two modes.
    A[0, 0] = 1/np.sqrt(2)
    A[0, 1] = 1/np.sqrt(2)
    ########################################

    if form == "density_matrix":
        # The ket must be filled in before this step; the outer product of an all-zero ket
        # would otherwise leave only the two hand-written entries (not a valid density matrix).
        psi = A.reshape(-1, 1)
        A = psi @ psi.conj().T

    return A.to(device)

In [None]:
# Histories filled by the training loop: total objective and projective norm per iteration.
Loss, N_Norm = [], []

EPOCHS = 100000  # number of optimisation iterations

In [None]:
# ---- Problem configuration ----
type = "complex"
nature = "unsymmetric"
dim = (2, 2)
form = "tensor"   # form is either 'tensor' or 'density_matrix'

# Target state and the optimisation problem built around it.
T = InputState(dim, type, form, device)
projective_norm = Projective_Norm(nature, dim, device, type, form)
Param_dict = projective_norm.Parameter_Init()

# A symmetric decomposition is only meaningful for a permutation-invariant target tensor.
if form == "tensor" and nature == "symmetric":
    projective_norm.is_symmetric_tensor(T)

# Adam, with the learning rate multiplied by 0.1 every 10000 scheduler steps.
learning_rate = 0.01
optimizer = torch.optim.Adam(Param_dict.values(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=10000, gamma=0.1)

In [None]:
# Main optimisation loop. `epoch` is advanced manually because the first iteration may be
# rejected and retried with fresh parameters.
epoch = 0
pbar = tqdm(range(EPOCHS), desc="Projective Norm convergence")
while epoch < EPOCHS:
    optimizer.zero_grad()
    loss = projective_norm.forward(T, Param_dict)
    loss.backward()
    # Pre-step values; committed to the histories only once this iteration is accepted.
    loss_value = loss.item()
    with torch.no_grad():
        norm_value = projective_norm.Norm_Cost(Param_dict).item()

    optimizer.step()

    # If calculated nuclear norm is too large in iteration 0, try another set of parameters and keep
    # repeating until it is small enough.
    if epoch == 0:
        with torch.no_grad():
            post_step_norm = projective_norm.Norm_Cost(Param_dict).item()
        if post_step_norm > np.prod(dim)/np.max(dim):
            print("Projective norm is too large at epoch 0, trying another set of parameters for faster convergence")
            Param_dict = projective_norm.Parameter_Init()
            optimizer = torch.optim.Adam(Param_dict.values(), lr=learning_rate)
            # A scheduler is bound to the optimizer it was built with; rebuild it (same settings)
            # or the learning-rate decay would keep acting on the discarded optimizer.
            scheduler = StepLR(optimizer, step_size=scheduler.step_size, gamma=scheduler.gamma)
            continue

    Loss.append(loss_value)
    N_Norm.append(norm_value)

    # Stop decaying once the learning rate has dropped below 1e-4.
    if scheduler.get_last_lr()[0] >= 1e-4:
        scheduler.step()

    if epoch % 100 == 0 or epoch == EPOCHS - 1:
        with torch.no_grad():
            recon_value = projective_norm.Reconstruction_Cost(T, projective_norm.Reconstructed_State(Param_dict)).item()
            norm_now = projective_norm.Norm_Cost(Param_dict).item()
        print(f"Epoch {epoch}, Loss: {loss.item()}, Reconstruction loss: {recon_value}, Projective Norm: {norm_now}, learning rate: {scheduler.get_last_lr()[0]}")

    epoch += 1
    pbar.update(1)
pbar.close()

In [None]:
# Objective-function history, skipping the first `first_epoch` iterations (presumably so the
# large early values do not dominate the y-axis). The x-values are offset so the axis shows the
# true epoch numbers instead of restarting at 0.
first_epoch = 1000
plt.plot(range(first_epoch, len(Loss)), Loss[first_epoch:])
plt.xlabel("Epochs")
plt.ylabel("Value of the Objective Function")
plt.title("Objective Function vs Epochs for calculating the nuclear norm")
plt.show()

In [None]:
# Persist the projective-norm history. The file name encodes the state label and the run
# configuration (dimensions, type, nature, form) so different runs do not overwrite each other.
state = 'separable_state'
dimension = '-'.join(str(d) for d in dim)
filename = './Pickle_files/' + state + '-' + dimension + '-' + type + '-' + nature + '-' + form + '.pkl'

# Create the output folder if needed so a long run is not lost to a missing directory.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, 'wb') as f:
    pickle.dump(N_Norm, f)

In [None]:
# Compare the optimised decomposition with the known answer for the input state:
# (|00> + |01>)/sqrt(2) = |0> (x) (|0> + |1>)/sqrt(2) is a normalised product state,
# so a single unit-weight term reproduces it.
analytical_rank = 1
nuclear_rank = projective_norm.Nuclear_Rank(Param_dict)

for label, value in (("Analytical Rank", analytical_rank), ("Nuclear Rank", nuclear_rank)):
    print(f"{label}: {value}")

analytical_projective_norm = 1
final_projective_norm = N_Norm[-1]  # last recorded value of the norm history

for label, value in (("Analytical Nuclear Norm", analytical_projective_norm),
                     ("Calculated Nuclear Norm", final_projective_norm)):
    print(f"{label}: {value}")


In [None]:
# Convergence plot: calculated projective norm per iteration against the analytical value.
fig = plt.figure(figsize=(10, 6))
ax = plt.subplot(111)

ax.plot(N_Norm, label="Calculated Projective Norm", color='orange', linestyle='-', linewidth=3)
# The reference line spans exactly the recorded history so both curves share the same x-range.
ax.plot([analytical_projective_norm]*len(N_Norm), label="Analytical Projective Norm", color='blue', linestyle='--', linewidth=1.5)

# Shrink the axes horizontally to leave room on the right for the summary text.
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

# Display the analytical and calculated projective norms and nuclear ranks beside the plot.
# int() renders the rank as a plain number even if Nuclear_Rank returned a tensor.
ax.text(1.02, 0.5,
        f'Analytical Projective Norm: {round(analytical_projective_norm, 4)}\n\nCalculated Projective Norm: {round(final_projective_norm, 4)}\n------------------------------------------------\nAnalytical Nuclear Rank: {analytical_rank}\n\nCalculated Nuclear Rank: {int(nuclear_rank)}',
        transform=ax.transAxes, fontsize=12, verticalalignment='center')

## Alternative (replace the ax.text call above): display only the analytical and calculated projective norms
# ax.text(1.02, 0.5, 
#         f'Analytical Projective Norm: {round(analytical_projective_norm, 4)}\n\nCalculated Projective Norm: {round(N_Norm[-1], 4)}', 
#         transform=ax.transAxes, fontsize=12, verticalalignment='center')

## Alternative (replace the ax.text call above): display only the calculated nuclear norm
# ax.text(1.02, 0.5, 
#         f'Calculated Nuclear Norm: {round(N_Norm[-1], 5)}', 
#         transform=ax.transAxes, fontsize=12, verticalalignment='center')

ax.set_xlabel("Number of Iterations", fontsize=14)
ax.set_ylabel("Value of the Projective Norm", fontsize=14)
fig.suptitle(r"Projective Norm of the 2 qubit separable state $|\psi\rangle = \frac{1}{\sqrt{2}}\left(|00\rangle + |01\rangle\right)$ in $\mathbb{C}$", fontsize=13, x=0.57, y=0.95)

ax.grid()
ax.legend(fontsize=12)

# Save before showing, so the file is written even if show() blocks or closes the figure.
fig_filename = "./Results/" + state + "_" + dimension + "_" + type + "_" + nature + ".png"
# NOTE(review): unlike the pickle file name this omits `form`, so tensor and density-matrix runs
# with the same configuration overwrite each other's figure.
os.makedirs(os.path.dirname(fig_filename), exist_ok=True)
fig.savefig(fig_filename, dpi=150, bbox_inches='tight')
plt.show()