In [27]:
import numpy as np
import torch
import opt_einsum
import itertools

import process_matrix
import utils

from tqdm import tqdm

torch.cuda.empty_cache()

import importlib
importlib.reload(utils)
importlib.reload(process_matrix)

<module 'process_matrix' from '/n/home10/jmcgreivy/convex-quantum-error-correction/process_matrix.py'>

In [28]:
q_s = 1
q_c = 2
device = "cuda"

In [29]:
# Generalized Amplitude Damping Krauss Operators
N = 0
g = 0

K_1 = torch.tensor([[np.sqrt(1 - N), 0],[0, np.sqrt(1 - N) * np.sqrt(1 - g)]], device = device)
K_2 = torch.tensor([[0,np.sqrt(g*(1-N))],[0,0]], device = device)
K_3 = torch.tensor([[np.sqrt(N)*np.sqrt(1-g), 0],[0,np.sqrt(N)]], device = device)
K_4 = torch.tensor([[0,0],[np.sqrt(g * N), 0]], device = device)
K = [K_1, K_2, K_3, K_4]

X_E = utils.krauss_to_X_E(K, q_c)

In [30]:
X_C = process_matrix.ProcessMatrix(q_s, q_c, device="cuda")
X_R = process_matrix.ProcessMatrix(q_c, q_s, device="cuda")

In [31]:

optimizer_C = torch.optim.SGD(X_C.parameters(), lr = 0.2)
optimizer_R = torch.optim.SGD(X_R.parameters(), lr = 0.2)


In [32]:
regularization = 1

pbar = tqdm(range(1000))
for epoch in pbar:    
    #Optimize X_C:
    W_C = (1 / (X_C().shape[0]**2)) * opt_einsum.contract("iljg,lmgs->misj", X_R().detach(), X_E)
    for _ in range(50):
        optimizer_C.zero_grad()
        
        f_avg = opt_einsum.contract("misj,misj->", X_C(), W_C).real
        X_C_identity = utils.sums_to_identity(X_C())
        X_C_PSD = utils.positive_eigenvalues(X_C())
        l = -f_avg + regularization*(X_C_identity + X_C_PSD)
        l.backward()
            
        optimizer_C.step()

        with torch.no_grad():
            X_C().data = utils.make_PSD(X_C().data)
            X_C().data = utils.make_sum_to_identity(X_C().data)
    
    #Optimize X_R
    W_R = (1 / (X_C().shape[0]**2)) * opt_einsum.contract("misj,lmgs->iljg", X_C().detach(), X_E)
    for _ in range(50):        
        # X_R Optimizing
        optimizer_R.zero_grad()
        
        f_avg = opt_einsum.contract("iljg,iljg->", X_R(), W_R).real
        X_R_identity = utils.sums_to_identity(X_R())
        X_R_PSD = utils.positive_eigenvalues(X_R())
        l = -f_avg + regularization*(X_R_identity + X_R_PSD)
        l.backward()
        
        optimizer_R.step()

        with torch.no_grad():
            X_R().data = utils.make_PSD(X_R().data)
            X_R().data = utils.make_sum_to_identity(X_R().data)
    
    description = f"Avg Fidelity : {f_avg}"
    pbar.set_description(description)

Avg Fidelity : 0.17436531305977207:  11%|█         | 108/1000 [00:38<05:20,  2.78it/s]


KeyboardInterrupt: 