In [None]:
from dataclasses import dataclass, field
import numpy as np
import math, cmath, functools
import itertools
import cvxpy as cp

In [None]:
# TODO: Lazy compute
@dataclass
class GraphState:
  """
  (Weighted) graph state on ``n`` qubits.

  The state is |G> = U |+...+>, where U applies the diagonal phase
  exp(i * phi) to every computational basis state once per edge (i, j) of the
  graph whose two qubits are both |1>.  ``phi = pi`` gives the usual CZ-based
  graph state.

  Bit convention used throughout: qubit ``i`` is bit ``i`` (least-significant
  bit first) of the computational basis index, so qubit 0 is always the LAST
  factor of a Kronecker product.

  Fields:
    n:       number of qubits / vertices.
    adj_mat: n x n symmetric 0/1 adjacency matrix with zero diagonal
             (stored as a numpy array after construction).
    phi:     edge phase angle in radians.
    name:    optional human-readable label (may be None).

  All underscore-prefixed fields are derived quantities filled in by
  ``recompute()``; call it again after mutating ``n``, ``adj_mat`` or ``phi``.
  """
  n: int
  adj_mat: list[list[int]]
  phi: float = math.pi
  name: str = None  # optional label; None when unnamed
  _state_vector: np.ndarray = field(init=False, repr=False)
  _unitary: np.ndarray = field(init=False, repr=False)
  _basis: list[np.ndarray] = field(init=False, repr=False)
  _eigenvalues: list[float] = field(init=False, repr=False)
  _density_matrix: np.ndarray = field(init=False, repr=False)
  _stabilizers: np.ndarray = field(init=False, repr=False)
  _MIS2: np.ndarray = field(init=False, repr=False)

  def __post_init__(self):
    """
    Validate the adjacency matrix and compute all derived attributes.

    Raises:
      ValueError: if adj_mat is not n x n, has a nonzero diagonal, is not
        symmetric, or contains entries other than 0 and 1.
    """
    self.adj_mat = np.array(self.adj_mat)

    # Check that adj_mat is n x n.
    if self.adj_mat.shape != (self.n, self.n):
      raise ValueError(f"adj_mat must be of shape ({self.n}, {self.n}), got {self.adj_mat.shape}.")

    # Check that the diagonal is zero.
    if not np.all(np.diag(self.adj_mat) == 0):
      raise ValueError("All diagonal elements of adj_mat must be zero.")

    # Check that the matrix is symmetric.
    if not np.array_equal(self.adj_mat, self.adj_mat.T):
      raise ValueError("adj_mat must be symmetric along the diagonal.")

    # Check that all entries are either 0 or 1.
    if not np.all((self.adj_mat == 0) | (self.adj_mat == 1)):
      raise ValueError("adj_mat must only contain zeros or ones.")

    self.recompute()

  def __eq__(self, other):
    """
    Two graph states are equal when their defining inputs (n, adj_mat, phi,
    name) are equal.

    The dataclass-generated __eq__ compares numpy arrays inside a tuple, which
    raises "truth value of an array is ambiguous" for n >= 2, so it is
    replaced here.  Derived attributes are functions of the inputs and are
    therefore not compared.
    """
    if other.__class__ is not self.__class__:
      return NotImplemented
    return (
      self.n == other.n
      and self.phi == other.phi
      and self.name == other.name
      and np.array_equal(self.adj_mat, other.adj_mat)
    )

  def recompute(self):
    """
    Recompute the graph state's dependent attributes.

    Order matters: the state vector needs the unitary, the basis needs the
    unitary, and the eigenvalues need both the basis and the density matrix.
    """
    self._unitary = self.graph_state_unitary()
    self._state_vector = self.graph_state()
    self._basis = self.graph_state_basis()
    self._density_matrix = np.outer(self._state_vector, self._state_vector.conj())
    self._eigenvalues = self.graph_state_eigenvalues()
    self._stabilizers = self.graph_state_stabilizers()
    self._MIS2 = self.graph_state_MIS2()

  def empty_graph_state_vector(self):
    """
    Returns |+ + + ... + >

    Input: number of qubits (self.n)
    Output: vector of length 2^n with all elements equal to 1/sqrt(2^n)
    """
    PLUS = (1/np.sqrt(2)) * np.array([1,1])
    return functools.reduce(np.kron, [PLUS] * self.n)

  def graph_state_unitary(self):
    """
    Returns U s.t. |G> = U |+ + + ... + >

    Input: number of qubits (self.n), adjacency matrix of graph state (self.adj_mat)
    Output: unitary matrix of size 2^n x 2^n with nonzero terms only on the diagonal
    """
    N = 2**self.n  # Number of basis states on n qubits
    phases = np.ones(N, dtype=complex)
    edge_phase = cmath.exp(1j * self.phi)  # loop-invariant, computed once
    states = np.arange(N)
    # Explore the upper triangle of adj_mat; every basis state in which both
    # qubit i and qubit j are |1> picks up one factor of exp(i*phi) per edge.
    for i in range(self.n):
      for j in range(i + 1, self.n):
        if self.adj_mat[i, j] != 0:
          # (states >> i) & 1 extracts the i-th bit of every basis index at once.
          both_set = ((states >> i) & (states >> j) & 1).astype(bool)
          phases[both_set] *= edge_phase
    return np.diag(phases)

  def graph_state(self):
    """
    Returns |G>

    Input: graph state unitary (self._unitary)
    Output: vector of length 2^n
    """
    return self._unitary @ self.empty_graph_state_vector()

  def graph_state_basis(self):
    r"""
    Returns weighted graph state basis

    The set of states $|W\rangle = \sigma_Z^W |G\rangle$ forms a basis for $(\mathbb{C}^2)^v$.
    Note that $\sigma_Z$ and $U_{ab}$ commute, and $\sigma_Z^W |+\rangle^V$ is an orthonormal basis of pure product states

    Both $\sigma_Z^W$ and U are diagonal, so each basis vector is obtained by
    multiplying diagonals elementwise instead of building dense 2^n x 2^n
    operators for every W.

    Input: number of qubits (self.n) and graph state unitary (self._unitary)
    Output: list of 2^n vectors (each of length 2^n), indexed by W
    """
    plus_state = self.empty_graph_state_vector()
    phases = np.diag(self._unitary)  # diagonal of U
    keep = np.array([1.0, 1.0])      # diagonal of the 2x2 identity
    flip = np.array([1.0, -1.0])     # diagonal of sigma_Z

    basis_vectors = []
    for W in range(2**self.n):
      # Bit i of W decides whether sigma_Z acts on qubit i.
      factors = [flip if (W >> i) & 1 else keep for i in range(self.n)]
      # Reverse so that qubit 0 ends up as the least-significant bit.
      signs = functools.reduce(np.kron, factors[::-1])
      basis_vectors.append(phases * (signs * plus_state))

    return basis_vectors

  def graph_state_eigenvalues(self):
    """
    Returns eigenvalues of graph state

    Each value is the expectation <W| rho |W> of the density matrix in the
    graph state basis vector |W>.

    Input: graph state density matrix in computational basis (self._density_matrix)
           and graph state basis (self._basis)
    Output: list of eigenvalues, one per basis vector
    """
    return [
      np.real_if_close(np.vdot(b, self._density_matrix @ b))
      for b in self._basis
    ]

  def graph_state_stabilizers(self):
    """
    Returns the stabilizer generators K_v of the graph state, or None.

    The edge gate has phase exp(i*phi):
      * phase == -1 (phi an odd multiple of pi): CZ edges, K_v = X_v Z_N(v).
      * phase == +1 (phi an even multiple of pi): the edge gate is the
        identity, the state is |+...+>, and K_v = X_v.
      * otherwise the state is not handled by this construction -> None.

    The phase is compared with a small tolerance rather than testing
    ``phi % pi == 0`` exactly, which is fragile in floating point and also
    wrongly treated even multiples of pi like CZ edges.

    Output: list of n matrices of size 2^n x 2^n, or None
    """
    edge_phase = cmath.exp(1j * self.phi)
    tol = 1e-9
    if abs(edge_phase + 1) < tol:
      edges_active = True
    elif abs(edge_phase - 1) < tol:
      edges_active = False
    else:
      return None

    Z = np.array([[1,0],[0,-1]])
    X = np.array([[0,1],[1,0]])
    I = np.array([[1,0],[0,1]])

    generators = []
    for vertex in range(self.n):
      operators = []
      for idx, relationship in enumerate(self.adj_mat[vertex]):
        if idx == vertex:  # Vertex itself
          operators.append(X)
        elif edges_active and relationship == 1:  # Neighbor of vertex
          operators.append(Z)
        else:  # Not a neighbor of vertex (or edges act trivially)
          operators.append(I)
      # Reverse so that qubit 0 ends up as the least-significant bit.
      generators.append(functools.reduce(np.kron, operators[::-1]))

    return generators

  def graph_state_MIS2(self):
    """
    Solve ILP for maximum independent set in the dist-1/2 adjacency matrix

    Two vertices conflict when they are at graph distance 1 or 2.  MOSEK is
    used when CVXPY reports it as installed; otherwise CVXPY's default solver
    for the problem is used.

    Output: array of vertex indices in the chosen independent set
    Raises:
      RuntimeError: if the solver returns no solution.
    """
    A = self.adj_mat
    # (i,j) entry in A(G)^2 + A is the number of length <= 2 walks from vertex i, vertex j
    # https://math.stackexchange.com/questions/1507470/prove-that-i-j-entry-in-ax2-is-the-number-of-length-2-walk-from-verte
    A_dist2 = np.linalg.matrix_power(A, 2) + A

    A_dist2 = (A_dist2 > 0).astype(int)  # Remove double counting of distance-1 and distance-2
    np.fill_diagonal(A_dist2, 0)  # Clear self-loops

    x = cp.Variable(self.n, boolean=True)  # x[i] is 1 if vertex i is in the independent set

    # Can only pick i or j if A_dist2[i,j] = 1 (removed double counting earlier)
    constraints = [
      x[i] + x[j] <= 1
      for i in range(self.n)
      for j in range(i + 1, self.n)
      if A_dist2[i, j]
    ]

    prob = cp.Problem(cp.Maximize(cp.sum(x)), constraints)
    # Prefer MOSEK (previous behaviour) but do not hard-require a commercial solver.
    solver = cp.MOSEK if cp.MOSEK in cp.installed_solvers() else None
    prob.solve(solver=solver)

    if x.value is None:
      raise RuntimeError(f"MIS2 ILP was not solved (status: {prob.status}).")

    # Numerical precision, solver should give solutions very close to 0 or 1
    ILP_solution = x.value.round().astype(int)
    return np.where(ILP_solution == 1)[0]

  def fidelity_to(self, other) -> float:
    """
    Compute fidelity |<other|self>|^2 w.r.t. another Graph state

    Raises:
      TypeError: if ``other`` is not a GraphState.
    """
    if not isinstance(other, GraphState):
      raise TypeError("`other` must be GraphState.")
    overlap = np.vdot(other.state_vector, self._state_vector)  # <phi|psi>
    return float(np.abs(overlap)**2)

  def max_schmidt_coeff(self, squared: bool = True) -> float:
    """
    Max (squared) Schmidt coeff over all bipartitions of the n-qubit pure state.

    Returns 0.0 when there is no nontrivial bipartition (n < 2).
    """
    tens = self._state_vector.reshape([2]*self.n)
    best = 0.0
    # A|B and B|A share their singular values, so subsets of size k <= n//2
    # already cover every nontrivial bipartition.
    for k in range(1, self.n // 2 + 1):
      for A in itertools.combinations(range(self.n), k):
        B = tuple(i for i in range(self.n) if i not in A)
        M = np.transpose(tens, A + B).reshape(2**len(A), -1)
        s0 = np.linalg.svd(M, compute_uv=False, hermitian=False)[0]  # largest SV
        val = s0*s0 if squared else s0
        if val > best:
          best = float(val)
    return best

  @property
  def state_vector(self) -> np.ndarray:
    """Property to access the graph state |G>."""
    return self._state_vector

  @property
  def unitary(self) -> np.ndarray:
    """Property to access the graph state unitary U."""
    return self._unitary

  @property
  def basis(self) -> list[np.ndarray]:
    """Property to access the complete set of graph state basis vectors."""
    return self._basis

  @property
  def eigenvalues(self) -> list[float]:
    """Property to access the graph diagonal state eigenvalues."""
    return self._eigenvalues

  @property
  def density_matrix(self) -> np.ndarray:
    """Density matrix of the graph state |G><G|."""
    return self._density_matrix

  @property
  def stabilizers(self) -> list[np.ndarray]:
    """Property to access the graph stabilizer generators (None if undefined for phi)."""
    return self._stabilizers

  @property
  def MIS2(self) -> np.ndarray:
    """Vertex indices of a maximum independent set of the distance-<=2 graph."""
    return self._MIS2