In [1]:
import datetime, os, sys, pickle
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
from tqdm import tqdm
from lambeq import NumpyModel, AtomicType, IQPAnsatz
from lambeq.backend.quantum import Ry, Diagram, Bra
from util import data_loader
import tensornetwork as tn
from scipy.linalg import logm, sqrtm
from contextuality.model import CyclicScenario, Model
import seaborn as sns
from scipy.stats import gaussian_kde

In [2]:
# Ry rotations used as measurement settings for the observables a, A, b, B.
# Not hermitian but normal
bases = {'a': Ry(0), 'A': Ry(np.pi/4), 'b': Ry(np.pi/8), 'B': Ry(3*np.pi/8)}

# One context per pairing of a first-party setting (a/A) with a second-party
# setting (b/B); each operator is the Kronecker product of the two rotations.
# Key order: 'ab', 'aB', 'Ab', 'AB'.
contexts = {
    first + second: np.kron(bases[first].array, bases[second].array)
    for first in ('a', 'A')
    for second in ('b', 'B')
}

In [3]:
class QModel(NumpyModel):
    """NumpyModel variant that exposes each diagram's rescaled output vector."""

    def __init__(self, use_jit: bool = False) -> None:
        super().__init__(use_jit)

    def get_output_state(self, diagrams):
        """Substitute the trained weights, contract every diagram and rescale it.

        Each contracted tensor is flattened, divided by the sum of its absolute
        values, and square-rooted element-wise.

        NOTE(review): this yields amplitudes only if the contraction produces
        non-negative probabilities. If it produces complex amplitudes, the sqrt
        is not a meaningful renormalisation. Confirm which case applies.

        Returns
        -------
        np.ndarray
            The per-diagram vectors stacked with ``np.array``.
        """
        substituted = self._fast_subs(diagrams, self.weights)
        outputs = []
        for circuit in substituted:
            assert isinstance(circuit, Diagram)
            contracted = tn.contractors.auto(*circuit.to_tn()).tensor
            flat = np.array(contracted).flatten()
            outputs.append(np.sqrt(flat / sum(abs(flat))))
        return np.array(outputs)

In [4]:
# Trained checkpoints, pickled diagrams and contextuality datasets for each of
# the four model variants (disjoint/spider x uncut/cut). All three dicts share
# the same keys in the same order.
model_path = dict(
    disjoint_uncut='runs/disjoint_uncut_130E/best_model.lt',
    disjoint_cut='runs/disjoint_cut_140E/best_model.lt',
    spider_uncut='runs/spider_uncut_200E/best_model.lt',
    spider_cut='runs/spider_cut_50E/best_model.lt',
)

diagram_path = dict(
    disjoint_uncut='dataset/diagrams/disjoint_uncut.pkl',
    disjoint_cut='dataset/diagrams/disjoint_cut.pkl',
    spider_uncut='dataset/diagrams/spider_uncut.pkl',
    spider_cut='dataset/diagrams/spider_cut.pkl',
)

# NOTE(review): 'spider_uncut' points at a `scenario422_...` file while the
# other three use `scenario442_...`. Confirm the filename is intentional.
data_path = dict(
    disjoint_uncut='dataset/contextuality_data/scenario442_disjoint_uncut.csv',
    disjoint_cut='dataset/contextuality_data/scenario442_disjoint_cut.csv',
    spider_uncut='dataset/contextuality_data/scenario422_spider_uncut.csv',
    spider_cut='dataset/contextuality_data/scenario442_spider_cut.csv',
)

In [294]:
def calc_violation(state: np.ndarray, context_ops=None) -> float:
    """Return the largest CHSH-style signed sum of context expectation values.

    For every context operator ``M`` this computes ``conj(state) @ (M @ state)``.
    It then returns the maximum, over which single context is sign-flipped, of
    ``sum(expectations) - 2 * expectation``.

    Parameters
    ----------
    state : np.ndarray
        State vector whose length matches the operators' dimension.
    context_ops : dict, optional
        Mapping from context name to operator matrix. Defaults to the
        module-level ``contexts`` dict when omitted.

    Returns
    -------
    The maximal signed sum. If the state or operators are complex, or the
    operators are not Hermitian, the value may be complex-typed.

    Raises
    ------
    ValueError
        If ``context_ops`` is empty (from ``max`` over an empty sequence).
    """
    # Resolve the default lazily so callers passing their own operators do not
    # depend on the global being defined.
    ops = contexts if context_ops is None else context_ops
    expectations = [np.conjugate(state) @ (op @ state) for op in ops.values()]
    total = sum(expectations)  # loop-invariant; identical for every candidate
    return max(total - 2 * exp for exp in expectations)

In [295]:
def density_op(state, tol=1e-12):
    """Return the pure-state density matrix ``|state><state|`` with noise zeroed.

    Parameters
    ----------
    state : array_like
        1-D state vector, real or complex.
    tol : float
        Real and imaginary parts whose magnitude is below ``tol`` are set to 0.

    Returns
    -------
    np.ndarray
        ``np.outer(state, conj(state))``. A real input gives a real matrix and a
        complex input gives a complex matrix.
    """
    dense_mat = np.outer(state, np.conjugate(state))
    # For a real array ``.real`` is a writable view of the array itself.
    dense_mat.real[np.abs(dense_mat.real) < tol] = 0.0
    # ``.imag`` of a real array is a read-only array of zeros, so writing into it
    # raises ValueError. Only clean it when an imaginary part actually exists.
    if np.iscomplexobj(dense_mat):
        dense_mat.imag[np.abs(dense_mat.imag) < tol] = 0.0
    return dense_mat

In [296]:
def log_mat(mat, tol=0.0):
    """Base-2 matrix logarithm via eigendecomposition.

    Diagonalises ``mat = V @ diag(evals) @ inv(V)`` with ``np.linalg.eig`` and
    returns ``V @ diag(log2(evals)) @ inv(V)``. This assumes ``mat`` is
    diagonalisable, as Hermitian density matrices are.

    Eigenvalues with ``|e| <= tol`` (by default exactly zero) are mapped to 0
    instead of -inf. This matches the ``0 * log 0 = 0`` convention of the von
    Neumann entropy. It can give a wrong answer elsewhere: for a relative
    entropy where the second matrix has a zero eigenvalue the first does not,
    the true value is +inf.

    Eigenvalues are logged in complex arithmetic, so tiny negative round-off
    eigenvalues produce a finite complex value instead of NaN.

    Returns
    -------
    np.ndarray
        complex128 matrix with the same shape as ``mat``.
    """
    evals, emat = np.linalg.eig(mat)  # columns of emat are the eigenvectors
    emat_inv = np.linalg.inv(emat)
    log_evals = np.log2(
        evals.astype(np.complex128),
        out=np.zeros(evals.shape, dtype=np.complex128),
        where=np.abs(evals) > tol,
    )
    # Scaling column j of V by log_evals[j] is V @ diag(log_evals).
    return (emat * log_evals) @ emat_inv

In [297]:
def partial_trace(dense_mat):
    """Reduced density matrices for a bipartite split of an n-qubit system.

    The qubits are split into a first subsystem A of ``floor(n/2)`` qubits and a
    second subsystem B of ``ceil(n/2)`` qubits. The matrix is taken to be
    ordered as A (x) B.

    Returns
    -------
    (rho_a, rho_b)
        ``rho_a = Tr_B(dense_mat)`` with shape ``(2**floor(n/2),) * 2``.
        ``rho_b = Tr_A(dense_mat)`` with shape ``(2**ceil(n/2),) * 2``.

    Raises
    ------
    ValueError
        If ``dense_mat`` is not a square matrix with a power-of-two dimension.
    """
    dense_mat = np.asarray(dense_mat)
    if dense_mat.ndim != 2 or dense_mat.shape[0] != dense_mat.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {dense_mat.shape}")
    dim = dense_mat.shape[0]
    if dim < 1 or dim & (dim - 1):
        raise ValueError(f"dimension {dim} is not a power of two")

    n_qubits = dim.bit_length() - 1
    dims_a = 2 ** (n_qubits // 2)                # floor(n/2) qubits
    dims_b = 2 ** (n_qubits - n_qubits // 2)     # ceil(n/2) qubits

    # Index as rho[a, b, a', b'] so each subsystem has its own pair of axes.
    tensor = dense_mat.reshape(dims_a, dims_b, dims_a, dims_b)
    rho_a = np.trace(tensor, axis1=1, axis2=3)   # sum over b == b'
    rho_b = np.trace(tensor, axis1=0, axis2=2)   # sum over a == a'
    return rho_a, rho_b

In [298]:
def calc_vne(dense_mat, direct=True, tol=1e-12):
    """Von Neumann entropy ``S(rho) = -Tr(rho log2 rho)`` in bits, rounded to 12 places.

    Parameters
    ----------
    dense_mat : np.ndarray
        Density matrix. The eigenvalue path assumes it is Hermitian.
    direct : bool
        True: evaluate ``-Tr(rho @ log_mat(rho))`` using this module's ``log_mat``.
        False: evaluate ``-sum(e * log2 e)`` over the eigenvalues ``e > tol``.
    tol : float
        Eigenvalues at or below this are discarded in the eigenvalue path. This
        drops zeros (0 log 0 = 0) and tiny negative round-off that would make
        ``log2`` return NaN.
    """
    if direct:
        ent = -np.trace(dense_mat @ log_mat(dense_mat))
    else:
        # The spectrum of a Hermitian matrix is real; drop round-off imaginary
        # parts so the entropy comes out as a real float.
        evals = np.linalg.eigvals(dense_mat).real
        evals = evals[evals > tol]
        ent = -np.sum(evals * np.log2(evals))
    return ent.round(12)

In [299]:
def qrel_ent(mat1, mat2):
    """Quantum relative entropy ``Tr(mat1 @ (log2 mat1 - log2 mat2))``.

    Both logarithms come from ``log_mat``, so how zero eigenvalues are handled
    follows that function's convention.
    """
    log_difference = log_mat(mat1) - log_mat(mat2)
    return np.trace(mat1 @ log_difference)

In [23]:
# Spider-topology loaders (cut, then uncut), each bound to its own checkpoint,
# dataset CSV and pickled diagrams.
data_sc = data_loader(
    scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2),
    model_path=model_path['spider_cut'],
)
data_sc.get_data(data_path['spider_cut'])
data_sc.get_diagrams(diagram_path['spider_cut'])

data_su = data_loader(
    scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2),
    model_path=model_path['spider_uncut'],
)
data_su.get_data(data_path['spider_uncut'])
data_su.get_diagrams(diagram_path['spider_uncut'])

In [24]:
# Disjoint-topology loaders (cut, then uncut), each bound to its own checkpoint,
# dataset CSV and pickled diagrams.
data_dc = data_loader(
    scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2),
    model_path=model_path['disjoint_cut'],
)
data_dc.get_data(data_path['disjoint_cut'])
data_dc.get_diagrams(diagram_path['disjoint_cut'])

data_du = data_loader(
    scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2),
    model_path=model_path['disjoint_uncut'],
)
data_du.get_data(data_path['disjoint_uncut'])
data_du.get_diagrams(diagram_path['disjoint_uncut'])

In [254]:
# Entanglement entropy of the reduced state (eoe) and quantum relative entropy
# between two Bra(0)-modified diagrams (qre), for every spider-cut diagram,
# evaluated with the spider-cut checkpoint. eoe[i] and qre[i] always refer to
# the same diagram.
eoe = []
qre = []
qm = QModel.from_checkpoint(model_path['spider_cut'])
for diagram in tqdm(data_sc.diagrams):
    try:
        state = qm.get_output_state([diagram])[0]
        dense_mat = density_op(state)
        rho_a, rho_b = partial_trace(dense_mat)
        entropy = calc_vne(rho_a, direct=False)

        # Apply a Bra(0) effect at wire index 0 and at wire index 1 respectively.
        diag1 = diagram.apply_gate(Bra(0), 0)
        diag2 = diagram.apply_gate(Bra(0), 1)
        state1 = qm.get_output_state([diag1])[0]
        state2 = qm.get_output_state([diag2])[0]
        mat1 = density_op(state1)
        mat2 = density_op(state2)
        rel_ent = qrel_ent(mat1, mat2).round(6)
    except Exception as err:
        # Diagrams with symbols unknown to the checkpoint are reported and skipped.
        tqdm.write(f"Error: {err}".strip(), file=sys.stderr)
        continue
    # Append only after both quantities succeeded, so the lists stay aligned.
    eoe.append(entropy)
    qre.append(rel_ent)

Error: 'Unknown symbol: storm__n@n.l_0'                                                                                                                               
Error: 'Unknown symbol: the__n.r@n@n.l_0'                                                                                                                             
Error: 'Unknown symbol: storm__n@n.l_0'                                                                                                                               
Error: 'Unknown symbol: storm__n@n.l_0'                                                                                                                               
Error: 'Unknown symbol: storm__n@n.l_0'                                                                                                                               
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3273/3273 [05:33<00:00,  9.81it/s

In [343]:
def plot_cnxt(x, y, z, title: str, xl: str, yl: str, save: bool=True) -> None:
    """Scatter two measures against each other, coloured by local point density.

    Colours are a Gaussian KDE estimate of point density in the (x, y) plane.
    Red reference lines are drawn at x = 1/6 and y = 1.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the points (equal length).
    z : array_like
        NOTE(review): currently unused. Confirm whether it was meant to drive
        the colours instead of the KDE density.
    title, xl, yl : str
        Figure title and axis labels. ``title`` is also used in the saved filename.
    save : bool
        If True, save to ``figures/<title>_<timestamp>`` (default format) before
        showing. The ``figures`` directory is created if missing.
    """
    # Copy so the shared registered colormap is not mutated by set_under.
    cmap = plt.get_cmap('viridis_r').copy()
    cmap.set_under('red')
    xy = np.vstack([x, y])
    density = gaussian_kde(xy)(xy)

    fig, ax = plt.subplots()
    scat = ax.scatter(x=x, y=y, c=density, cmap=cmap)
    # Presumably the sheaf (x = 1/6) and CbD (y = 1) contextuality thresholds
    # -- TODO confirm.
    ax.axvline(x=1/6, color='r', linestyle='-')
    ax.axhline(y=1, color='r', linestyle='-')
    ax.set_xlabel(xl)
    ax.set_ylabel(yl)
    ax.set_title(title)
    # The colours are KDE densities, so label them as such.
    fig.colorbar(scat, ax=ax, label='Point density (KDE)')
    scat.set_alpha(0.5)
    if save:
        os.makedirs('figures', exist_ok=True)
        stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        fig.savefig('figures/' + title + '_' + stamp)
    plt.show()

In [5]:
# Four independent loaders without a bound checkpoint. Each gets its own
# CyclicScenario over the observables a, b, A, B; they are created in the
# order spider/disjoint for "wrong_ref", then spider/disjoint for "right_ref".
wrong_ref_spider, wrong_ref_disjoint, right_ref_spider, right_ref_disjoint = (
    data_loader(scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2))
    for _ in range(4)
)

In [6]:
# Load the pickled diagram set into the disjoint "wrong_ref" loader.
wrong_ref_disjoint.get_diagrams(
    'dataset/diags_wrong_ref.pkl'
)

In [None]:
# Generate data for the spider "wrong_ref" loader using the spider-uncut checkpoint.
# TODO(review): no get_diagrams() call for wrong_ref_spider is visible in this
# notebook (only wrong_ref_disjoint loads 'dataset/diags_wrong_ref.pkl'). Load
# the matching diagram set before running this cell.
wrong_ref_spider.update_model(model_path['spider_uncut'])
wrong_ref_spider.gen_data()

Error: 'Unknown symbol: storm__n@n.l_0'                                                                                                                               
Error: 'Unknown symbol: the__n.r@n@n.l_0'                                                                                                                             
 25%|███████████████████████████████▏                                                                                              | 765/3086 [01:01<02:55, 13.21it/s]

In [None]:
# Generate data for the disjoint "wrong_ref" loader (its diagrams were loaded
# from 'dataset/diags_wrong_ref.pkl') using the disjoint-uncut checkpoint.
wrong_ref_disjoint.update_model(model_path['disjoint_uncut'])
wrong_ref_disjoint.gen_data()

In [8]:
# Two-qubit state with equal weight on |00> and |11>: (|00> + |11>) / sqrt(2).
state = np.sqrt(1/2) * np.array([1.0, 0.0, 0.0, 1.0])

In [9]:
# Squared magnitudes of the rotated amplitudes, one row per context in the
# order ab, aB, Ab, AB. These are outcome probabilities when `state` is
# normalised. np.abs(...)**2 (rather than np.square) keeps the result a real
# |amplitude|^2 for complex states too; for real inputs the values are identical.
pr_dist = np.array([np.abs(contexts[name] @ state) ** 2
                    for name in ('ab', 'aB', 'Ab', 'AB')])