In [155]:
import torch
import numpy as np
def tensor_to_mps(psi, leg_dim, max_bond_dim, Print=False):
    #leg_dim의 homogeneous 가정함
    a=np.log2(psi.numel()).item()/np.log2(leg_dim).item()
    if a%1 != 0 :
        print('input not valid')
    N = int(a)  # 원래 텐서의 크기
    psi0=psi
    mps_tensors = []
    bond_dim=1
    for n in range(N-1):
        psi = psi.reshape(leg_dim*bond_dim,-1)
        d=bond_dim
        U, S, V = torch.linalg.svd(psi, full_matrices=False)
        bond_dim = min(S.size(0), max_bond_dim)
        U = U[:, :bond_dim]
        S = torch.diag(S[:bond_dim])
        S=torch.tensor(S, dtype=torch.cfloat)
        V = V[:bond_dim,:]
        if n==0:
            mps_tensors.append(U.reshape(1,leg_dim,-1))
        else:
            mps_tensors.append(U.reshape(d,leg_dim,-1))
        psi = torch.matmul(S, V)
    mps_tensors.append(psi.reshape(-1, leg_dim,1))
    #return type : [ (1*leg*leg) (leg*leg*bond) (bond*leg*bond) ..... (bond*leg*1)]
    if Print:
        PrintMPS(mps_tensors)
        result = contract_mps(mps_tensors)
        print('loss of MPS(%) : ',100*torch.sum(torch.abs(result.flatten() - psi0.flatten()))/torch.sum(torch.abs(psi0.flatten())).item())
        print('###############################################')
    return mps_tensors
def contract_mps(mps_tensors):
    #leg_dim의 homogeneous 가정함
    # MPS 텐서 목록이 비어 있으면 에러를 발생시킵니다.
    if len(mps_tensors) == 0:
        raise ValueError("MPS tensor list cannot be empty")
    leg_dim=mps_tensors[0].size(1)
    # 첫 번째 텐서를 초기 결과로 설정합니다.
    result = mps_tensors[0]

    # 나머지 텐서를 순차적으로 contract합니다.
    for i in range(1, len(mps_tensors)):
        result = torch.einsum("ilj, jkm->ilkm", result, mps_tensors[i]).reshape(1,leg_dim**(i+1),-1)
    # return type : 1 * (leg_dim)^n * 1
    return result
def PrintMPS(mps):
    print('###############################################')
    print('printing MPS vector')
    print('###############################################')
    for idx, tensor in enumerate(mps):
        print(f"MPS tensor {idx}: {tensor.shape}")
    print('###############################################')

In [156]:
# 양자 상태 벡터 생성 (예: 크기가 2^4인 무작위 복소수 벡터)
psi = torch.rand(3**4) + 1j * torch.rand(3**4)

# 양자 상태 벡터를 MPS로 변환
mps_tensors = tensor_to_mps(psi, 3, max_bond_dim=4,Print=True)

# MPS 텐서 출력
#PrintMPS(mps_tensors)


# 예시 MPS 텐서 목록
mps_tensors = [
    torch.tensor([[[1, 2], [3, 4]]], dtype=torch.float32),
    torch.tensor([[[5, 6], [7, 8]]], dtype=torch.float32),
    torch.tensor([[[9, 10], [11, 12]]], dtype=torch.float32)
]

# MPS 텐서들을 contract합니다.
result = contract_mps(tensor_to_mps(psi, 3, max_bond_dim=7))
#print('loss of MPS')
#print(torch.sum(torch.abs(result.flatten() - psi.flatten()))/torch.sum(torch.abs(psi.flatten())))

###############################################
printing MPS vector
###############################################
MPS tensor 0: torch.Size([1, 3, 3])
MPS tensor 1: torch.Size([3, 3, 4])
MPS tensor 2: torch.Size([4, 3, 3])
MPS tensor 3: torch.Size([3, 3, 1])
###############################################
loss of MPS(%) :  tensor(21.6435)
###############################################


  S=torch.tensor(S, dtype=torch.cfloat)
