author: Marcin Płodzień 

e-mail: marcin.plodzien@icfo.eu; mplodzien@gmail.com

www: https://github.com/MarcinPlodzien, https://sites.google.com/site/marcinplodzienphysics/

# Quantum State Tomography with Classical Shadows. Implementation from Scratch

In this notebook, we introduce the problem of reconstructing a quantum state from measurements, known as Quantum State Tomography (QST), and we briefly describe the state-of-the-art QST protocol known as Classical Shadows, introduced by Hsin-Yuan Huang, Richard Kueng, and John Preskill [[1]](https://www.nature.com/articles/s41567-020-0932-7).

Finally, we present an implementation of classical shadows QST from scratch in Python (with PyTorch).
The aim of this notebook is to give a hands-on experience with the classical shadows protocol for QST.

## 1. Quantum State Tomography

Quantum state tomography (QST) lies at the heart of quantum information theory, serving as a fundamental tool for characterizing and understanding quantum systems. 

QST is a crucial technique for characterizing and reconstructing quantum states by performing projective measurements in an informationally complete set of bases. However, the complexity of state reconstruction grows exponentially with the number of system constituents, such as spins or qubits. With the growing number of qubits in current quantum technologies, scalable QST schemes are in high demand. Various QST methods have been proposed to address this issue, such as QST for sparse quantum states by [compressive sensing](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.105.150401), or for [permutationally invariant quantum states](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.105.250403). However, the number of measurements these methods require for general states still makes them impractical for larger numbers of qubits.

To overcome the scaling concerns, the [classical shadows technique](https://www.nature.com/articles/s41567-020-0932-7) has been proposed. This method uses randomized measurements to build classical "shadows" of the state, whose average approximates the exact density matrix within some error bound. It requires far fewer measurements than standard QST methods and provides an efficient way of characterizing large quantum systems.

Side comment:
On Noisy Intermediate-Scale Quantum (NISQ) devices, the problem of density matrix reconstruction can be tackled in the language of [low-depth Variational Quantum Circuits (VQC) ansatzes](https://journals.aps.org/pra/abstract/10.1103/PhysRevA.101.052316). With the help of classical machine learning optimization techniques, an optimal set of quantum circuit parameters can be found to maximize the fidelity of the VQC with the target highly-entangled quantum state. As a result, a many-body Bell-correlated state can be stored within a data structure of quantum gates. Since the number of parameters of such a circuit grows polynomially with the number of qubits, only a polynomial number of measurements is required for quantum state tomography.

## 2. Classical Shadows with randomized measurements protocol
We consider the chain of $N$ spins-$1/2$ described by a density matrix  $\hat{\varrho}$ decomposed in the computational basis
\begin{equation}
\{|\vec{s}\rangle\}= \{ |s_1, s_2,s_3,\dots,s_N\rangle\} = \Big\{ \bigotimes_{j=1}^{N}|s_j\rangle \Big\},
\end{equation}
with $s_j = \pm 1$.

Classical shadows tomography aims to reconstruct the target quantum state $\hat{\varrho}$ based on $M$ measurements performed on identical copies of it. In the $m$-th measurement, a random unitary 
$\hat{U}_{m} = \bigotimes_{j = 1}^{N}\hat{u}_j^{(m)}$ is applied to the target state 
\begin{equation}
    \hat{\varrho}_{m} = \hat{U}_{m}\hat{\varrho}\hat{U}_m^\dagger,
\end{equation}
where $\hat{u}^{(m)}_j$ are random operators chosen from some ensemble ${\cal U}$. Next, after a projective measurement in the computational basis,
we obtain a bit-string $\{s_1^{(m)},\dots,s_N^{(m)}\}$ and construct the classical shadow of the initial state as
\begin{equation}
    \hat{\rho}_{m}  = {\cal M}^{-1}\bigg[\bigotimes_{j=1}^{N} \hat{u}_j^{(m)\dagger} |s_j^{(m)}\rangle\langle s_j^{(m)}|\hat{u}_j^{(m)}\bigg],
\end{equation}
where the inverse map ${\cal M}^{-1}$ is determined by ${\cal U}$ [[1]](https://www.nature.com/articles/s41567-020-0932-7). 

Here, we take ${\cal U}$ to be the set of single-spin Pauli-measurement rotations, i.e. the unitaries that map the eigenbasis of $X$, $Y$, or $Z$ at a given spin $j$ onto the computational basis: ${\cal U} = \{ \hat{H}, \hat{H}\hat{S}^\dagger, \hat{Z} \}$ for $X$, $Y$, and $Z$, respectively. For the $Z$ measurement no basis change is needed; $\hat{Z}$ is diagonal in the computational basis, so it leaves the measurement outcomes and the projectors $|s_j\rangle\langle s_j|$ unchanged. For this ensemble the inverse map factorizes, 
${\cal M}^{-1} = \bigotimes_{j=1}^{N}{\cal M}_1^{-1}$, where ${\cal M}^{-1}_n[\cdot] = (2^n+1)[\cdot] - \mathbb{1}_{2^n}{\rm Tr}([\cdot])$, and the $m$-th classical shadow reads  [[1]](https://www.nature.com/articles/s41567-020-0932-7)
\begin{equation}
    \hat{\rho}_{m} = \bigotimes_{j=1}^{N} \big[3 \hat{u}_j^{(m)\dagger} |s_j^{(m)}\rangle\langle s_j^{(m)}|\hat{u}_j^{(m)} - \mathbb{1}_{2}\big].
\end{equation}
After $M$ realizations of classical shadows the reconstructed density matrix $\hat{\varrho}^*$ is given by 
\begin{equation}
    \hat{\varrho}^{*} = \frac{1}{M}\sum_{m=1}^M\hat{\rho}_{m}.
\end{equation}

In other words, for the $m$-th copy of the target state $\hat{\varrho}$, each spin in the chain is rotated to the eigenbasis of a randomly chosen operator $X$, $Y$, or $Z$. Next, a projective measurement is performed and we collect the resulting bit-string, i.e. the vector $\{s_1, s_2, \ldots, s_N\}$, where $s_j = \pm 1$.
Then we apply the inverse rotation to each spin individually, i.e. we rotate it back to the initial basis, and in the final step we apply the inverse map ${\cal M}^{-1}$ to obtain the $m$-th classical shadow $\hat{\rho}_{m}$. The reconstructed density matrix is the average over many realizations of $\hat{\rho}_{m}$.
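
The single-spin version of this recipe can be checked exactly, without any sampling: averaging $3\hat{u}^\dagger|s\rangle\langle s|\hat{u} - \mathbb{1}_2$ over the three measurement settings and both outcomes, weighted by the Born probabilities, should return the input state. The sketch below does this with NumPy instead of PyTorch for brevity; the function name `expected_shadow` and the test state are illustrative choices, not part of the notebook.

```python
import numpy as np

# Single-spin operators, matching the definitions used in this notebook.
I2 = np.eye(2, dtype=complex)
sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_y = np.array([[0, -1j], [1j, 0]], dtype=complex)
sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)
hadamard = (sigma_x + sigma_z) / np.sqrt(2)
s_dagger = np.array([[1, 0], [0, -1j]], dtype=complex)

# Rotations for the X, Y and Z measurement settings.
unitaries = [hadamard, hadamard @ s_dagger, sigma_z]
# Projectors on the two computational-basis outcomes.
projectors = [np.diag([1, 0]).astype(complex), np.diag([0, 1]).astype(complex)]


def expected_shadow(rho):
    """Exact average of the single-spin classical shadow over settings and outcomes."""
    average = np.zeros((2, 2), dtype=complex)
    for u in unitaries:
        rho_rotated = u @ rho @ u.conj().T
        for projector in projectors:
            probability = np.trace(projector @ rho_rotated).real
            snapshot = 3 * u.conj().T @ projector @ u - I2
            average += probability * snapshot / len(unitaries)
    return average


# Illustrative mixed state with Bloch vector (0.3, -0.4, 0.5).
rho = 0.5 * (I2 + 0.3 * sigma_x - 0.4 * sigma_y + 0.5 * sigma_z)
print(np.allclose(expected_shadow(rho), rho))
```

In an actual experiment only a finite number of snapshots is available, so the average fluctuates around the target state instead of matching it exactly.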


The power of classical shadows tomography lies in cheap post-processing requiring only $2NM$ numbers, which must be stored in a classical memory, i.e., $M$ instances of information about $N$ randomly chosen operators and accompanying bit strings of length $N$.
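
To get a feeling for these numbers, the short sketch below compares the $2NM$ stored values with the $4^N$ complex entries of a full density matrix. The function names and the value $M = 10^4$ are illustrative assumptions; the number of shadows actually needed depends on the quantity being estimated and on the target accuracy.

```python
def shadow_storage(n_spins, n_shadows):
    """Numbers stored by the protocol: one basis label and one outcome per spin per shadow."""
    return 2 * n_spins * n_shadows


def density_matrix_entries(n_spins):
    """Number of complex entries of a full 2^N x 2^N density matrix."""
    return (2 ** n_spins) ** 2


n_shadows = 10_000  # illustrative value, not a recommendation
for n_spins in (3, 10, 20):
    print(n_spins, shadow_storage(n_spins, n_shadows), density_matrix_entries(n_spins))
```

The stored data grows linearly with the number of spins at fixed $M$, whereas the explicit density matrix grows exponentially.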

## 3. Implementation

### 3.1 Exact representation of spin-1/2 operators on a 1D chain

The starting point is to implement the Pauli operators on a spin-$1/2$ chain with $L$ spins (the chain length denoted $N$ in Section 2).
The Pauli spin operators $X_i,Y_i,Z_i$ acting on the $i$-th spin are defined as:
\begin{equation}
 \begin{split}
 X_i  &= \mathbb{1}_1\otimes\dots\otimes\mathbb{1}_{i-1}\otimes\hat{\sigma}^x\otimes\mathbb{1}_{i+1}\otimes\dots\otimes\mathbb{1}_{L},\\
 Y_i  &= \mathbb{1}_1\otimes\dots\otimes\mathbb{1}_{i-1}\otimes\hat{\sigma}^y\otimes\mathbb{1}_{i+1}\otimes\dots\otimes\mathbb{1}_{L},\\
 Z_i  &= \mathbb{1}_1\otimes\dots\otimes\mathbb{1}_{i-1}\otimes\hat{\sigma}^z\otimes\mathbb{1}_{i+1}\otimes\dots\otimes\mathbb{1}_{L},
 \end{split}
\end{equation}
where $\hat{\sigma}^{x,y,z}$ are $2\times2$ Pauli operators.

In a similar manner we can define the Hadamard matrix $H_i$ acting on the $i$-th spin, i.e.
\begin{equation}
H_i = \mathbb{1}_1\otimes\dots\otimes\mathbb{1}_{i-1}\otimes H\otimes\mathbb{1}_{i+1}\otimes\dots\otimes\mathbb{1}_{L},
\end{equation}
\end{equation}
where $H = \frac{\hat{\sigma}^x + \hat{\sigma}^z}{\sqrt{2}}$.
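
The Hadamard matrix is needed because it rotates the eigenbasis of $\hat{\sigma}^x$ onto the computational basis, and combined with $\hat{S}^\dagger$ it does the same for $\hat{\sigma}^y$. A quick numerical check of both statements, written with NumPy for brevity (the helper `rotate` is an illustrative name, not part of the notebook):

```python
import numpy as np

sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_y = np.array([[0, -1j], [1j, 0]], dtype=complex)
sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)
hadamard = (sigma_x + sigma_z) / np.sqrt(2)
s_dagger = np.array([[1, 0], [0, -1j]], dtype=complex)


def rotate(U, A):
    """Return U A U^dagger."""
    return U @ A @ U.conj().T


# H maps sigma_x onto sigma_z, and H S^dagger maps sigma_y onto sigma_z.
x_rotated = rotate(hadamard, sigma_x)
y_rotated = rotate(hadamard @ s_dagger, sigma_y)
print(np.allclose(x_rotated, sigma_z), np.allclose(y_rotated, sigma_z))
```

Measuring in the computational basis after applying one of these rotations is therefore equivalent to measuring $\hat{\sigma}^x$ or $\hat{\sigma}^y$ on the original state.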

In [1]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch as pt
from torch import matrix_exp as expm
from torch.linalg import eigh as eigh
import numpy as np

In [2]:
id_local = pt.tensor([[1.,0],[0,1.]]) + 0*1j
sigma_x = pt.tensor([[0,1.],[1.,0]]) + 0*1j
sigma_y = 1j*pt.tensor([[0,-1.],[1.,0]]) + 0*1j
sigma_z = pt.tensor([[1.,0],[0,-1.]]) + 0*1j
hadamard = 1.0/pt.sqrt(pt.tensor(2))*(sigma_x + sigma_z)+1j*0    
s = pt.tensor([[1,0],[0,1j]])+1j*0
s_dagger = pt.tensor([[1,0],[0,-1j]])+1j*0

M_u = pt.tensor([[1,0],[0,0]]) + 0*1j # projector on spin-up, or |1>
M_d = pt.tensor([[0,0],[0,1]]) + 0*1j # projector on spin-down, or |-1>

def get_Identity(k):  # returns k-tensor product of the identity operator, ie. Id^k
    Id = id_local
    for i in range(0, k-1):
        Id = pt.kron(Id, id_local)
    return Id
       
def get_chain_operator(A, L, i):
    Op = A
    if(i == 1):
        Op = pt.kron(A,get_Identity(L-1))
        return Op
    if(i == L):
        Op = pt.kron(get_Identity(L-1),A)
        return Op
    if(i>0 and i<L):
        Op = pt.kron(get_Identity(i-1), pt.kron(Op, get_Identity(L-i)))
        return Op

def get_chain_operators(L):
    if(L>1):
      Id = get_chain_operator(id_local, L, 1)
      X = {}
      Y = {}
      Z = {}
      H = {}
      S = {}
      S_dagger = {}

      for qubit_i in range(1, L+1):    # Loop over indices on a chain
          X[qubit_i] = get_chain_operator(sigma_x, L, qubit_i)      
          Y[qubit_i] = get_chain_operator(sigma_y, L, qubit_i)      
          Z[qubit_i] = get_chain_operator(sigma_z, L, qubit_i)      
          H[qubit_i] = get_chain_operator(hadamard, L, qubit_i)
          S[qubit_i] = get_chain_operator(s, L, qubit_i)
          S_dagger[qubit_i] = get_chain_operator(s_dagger, L, qubit_i)
      return Id, X, Y, Z, H, S, S_dagger
    else:
      return id_local, sigma_x, sigma_y, sigma_z, hadamard, s, s_dagger

Operators are expressed in the eigenbasis of the $Z = \sum_{i=1}^{L} Z_i$ operator:

\begin{equation}
 \begin{split}
   |v_1\rangle & = |\uparrow \uparrow \dots \uparrow \rangle \\
   |v_2\rangle & = |\uparrow \uparrow \dots \downarrow \rangle \\
   & \vdots \\
   |v_D\rangle & = |\downarrow \downarrow \dots \downarrow \rangle \\
 \end{split}
\end{equation}

\begin{equation}
\begin{split}
   X & = \sum_{k,l} \langle v_k|X|v_l\rangle |v_k\rangle\langle v_l| \\
   Y & = \sum_{k,l} \langle v_k|Y|v_l\rangle |v_k\rangle\langle v_l| \\
   Z & = \sum_{k,l} \langle v_k|Z|v_l\rangle |v_k\rangle\langle v_l| \delta_{k,l},
\end{split}
\end{equation}

In the following we denote: $\uparrow \equiv 1$, $\downarrow \equiv -1$.


Let us construct our spin-$1/2$ chain Hilbert space:

In [3]:
L = 3
D = 2**L
Id, X, Y, Z, H, S, S_dagger = get_chain_operators(L)

def get_spin_basis(L):
    D = 2**L
    basis = pt.zeros((D,L)) + 0*1j
    for v_i in range(0,D):
        fock_state = pt.zeros(D) + 0*1j
        fock_state[v_i] = 1
        for i in range(1,L+1):
            tmp = pt.vdot(fock_state, Z[i]@fock_state)
            basis[v_i,i-1] = tmp.real
    
    print("Fock basis: ")
    for v_i in range(0,D):  
        string_fock_vector = "|v_" + "{:03d}".format(v_i) + "> = |"
        for i in range(1,L+1):
            tmp = int(basis[v_i,i-1].item().real)
            if(tmp==1):
                string_plus_minus = " {:1d}".format(tmp)
            if(tmp==-1):
                string_plus_minus = "{:1d}".format(tmp)
                
             
            string_fock_vector = string_fock_vector + string_plus_minus + " "
        string_fock_vector = string_fock_vector + ">"
        print(string_fock_vector)
    return basis

basis = get_spin_basis(L)

Fock basis: 
|v_000> = | 1  1  1 >
|v_001> = | 1  1 -1 >
|v_002> = | 1 -1  1 >
|v_003> = | 1 -1 -1 >
|v_004> = |-1  1  1 >
|v_005> = |-1  1 -1 >
|v_006> = |-1 -1  1 >
|v_007> = |-1 -1 -1 >


Helper functions

In [27]:
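def sqrt_matrix(A):
  # Matrix square root via eigendecomposition: A = Q diag(evals) Q^{-1}
  evals, Q = pt.linalg.eig(A)
  return Q @ pt.diag(pt.sqrt(evals)) @ pt.linalg.inv(Q)

def get_trace_distance(rho, sigma):
  # Trace distance: 0.5 * Tr sqrt[(rho - sigma)^dagger (rho - sigma)]
  delta = rho - sigma
  return 0.5*pt.trace(sqrt_matrix(delta.conj().T@delta)).real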

def print_density_matrix(rho, string_title, print_eigenvalues):
  print("="*40)
  print(string_title)
  print("Norm:   Tr[rho] = {:2.3f}".format(pt.trace(rho).item().real))
  print("Purity: Tr[rho^2] = {:2.3f}".format(pt.trace(rho@rho).item().real))
  lambda_i, ket_lambda_i = pt.linalg.eig(rho)
  lambda_i = lambda_i.real

  # Print the eigenvalues only when requested
  if print_eigenvalues:
     string_evals = ""
     for i in range(0,D):
        string_evals = string_evals + " lambda_" + "{:03d}".format(i) + "  = {:2.3f}".format(lambda_i[i].item()) + "\n"
     print("Eigenvalues :")
     print(string_evals)
  
  print("Real part:")
  string_columns = "     "
  for i in range(0,D):
    string_columns = string_columns + "      " + "|{:03d}".format(i)+">"
  print(string_columns)
  for i in range(0,D):
    string_row = "|" + "{:03d}".format(i)+">"
    for j in range(0,D):
      string_row = string_row + "   " + "   {:2.3f}".format(rho[i,j].item().real)
    print(string_row)
  print("\n")

  print("Imaginary part:")
  string_columns = "     "
  for i in range(0,D):
    string_columns = string_columns + "      " + "|{:03d}".format(i)+">"
  print(string_columns)
  for i in range(0,D):
    string_row = "|" + "{:03d}".format(i)+">"
    for j in range(0,D):
      
      string_row = string_row + "   " + "  {: 2.3f}".format(rho[i,j].item().imag)
    print(string_row)
  print("\n")


  return
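
Because the difference of two density matrices is Hermitian, the trace distance can also be computed from its eigenvalues, as half the sum of their absolute values. The sketch below shows this alternative with NumPy rather than PyTorch for brevity; the function name `trace_distance_hermitian` is illustrative and not part of the notebook.

```python
import numpy as np


def trace_distance_hermitian(rho, sigma):
    """Trace distance from the eigenvalues of the Hermitian matrix rho - sigma."""
    eigenvalues = np.linalg.eigvalsh(rho - sigma)
    return 0.5 * np.sum(np.abs(eigenvalues))


# Two single-spin pure states as a small example: |0><0| and |+><+|.
rho_zero = np.array([[1, 0], [0, 0]], dtype=complex)
rho_plus = 0.5 * np.array([[1, 1], [1, 1]], dtype=complex)
print(trace_distance_hermitian(rho_zero, rho_plus))
```

For two pure states this should agree with $\sqrt{1 - |\langle\psi|\phi\rangle|^2}$, which gives a simple consistency check.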

### 3.2 Sampling a discrete probability distribution
To sample from a given discrete probability distribution $\vec{p} = \{p_1, p_2, \ldots,p_D\}$, $\sum_v p_v = 1$, we draw a random number $r \in [0,1]$ and find the smallest $v_{\rm max}$ for which the cumulative probability $p_c = \sum_{v=1}^{v_{\rm max}} p_v$ fulfils $p_c \geq r$. The index $v_{\rm max}$ is then a sample drawn from $\vec{p}$.
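
A minimal pure-Python sketch of this cumulative-sum sampling is given below. It returns the first index at which the running sum reaches the random number $r$; the function name `sample_index` and the optional argument `r` (used to make the example reproducible) are illustrative assumptions.

```python
import random


def sample_index(probabilities, r=None):
    """Return the smallest index whose cumulative probability reaches r."""
    if r is None:
        r = random.random()
    cumulative = 0.0
    for v, p in enumerate(probabilities):
        cumulative += p
        if cumulative >= r:
            return v
    # Guard against rounding errors when the probabilities sum to slightly less than 1.
    return len(probabilities) - 1


# Example distribution with weight only on the first and last outcome.
p = [0.5, 0.0, 0.0, 0.5]
print(sample_index(p, r=0.3), sample_index(p, r=0.7))
```

Outcomes with zero probability are never selected for $r > 0$, because the cumulative sum does not change when they are added.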

### 3.3 Simulating projective measurement on the quantum state $\hat{\rho}$

In the projective measurement protocol (in the computational basis) the state of the system $\hat{\rho}$ collapses to one of the basis vectors, i.e. $\hat{\rho} \to |v\rangle\langle v|$, with a probability given by the probability distribution $\vec{p} =\{p_1, p_2, \ldots, p_D\} = {\rm diag}(\hat{\rho})$.

To simulate the collapse of the wave function onto one of the basis vectors after a projective measurement, we sample from the probability distribution $\vec{p}$ of the considered quantum state. The sampled basis vector is $|v\rangle = |s_1, s_2, \ldots, s_L\rangle$, where $s_j = \pm 1$. We refer to $\vec{s} = \{s_1,\ldots,s_L\}$ as a "bit-string".
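
With the basis ordering used in this notebook (spin 1 is the leftmost, most significant position, and $\uparrow \equiv 1$ comes before $\downarrow \equiv -1$), the bit-string of a basis vector can also be read off directly from the binary digits of its index. The helper below is an illustrative sketch of that mapping, not the routine used in the notebook:

```python
def index_to_bit_string(v, n_spins):
    """Map a basis index v to spins s_j = +1/-1, with spin 1 as the most significant bit."""
    bits = [(v >> (n_spins - 1 - j)) & 1 for j in range(n_spins)]
    # Binary digit 0 corresponds to spin up (+1), digit 1 to spin down (-1).
    return [1 - 2 * b for b in bits]


for v in range(2 ** 3):
    print(v, index_to_bit_string(v, 3))
```

For $L = 3$ this reproduces the basis listing printed earlier, e.g. index 1 maps to `[1, 1, -1]` and index 4 to `[-1, 1, 1]`.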

In [5]:
def prepare_measurement_global_(rho, verbose = False): #we cast total wave-function onto single Fock vector
  r = np.random.rand(1)[0]
  p_c = 0 # cumulative probability
  for v in range(0,D):                           
    p_c = p_c + rho[v,v].real
    if(p_c >= r):
      break
  psi_measured = pt.zeros(D) + 0*1j # projective measurement
  psi_measured[v] = 1      
  bit_string = basis[v,:]
  if verbose:
   print("== prepare measurement ==")      
   print("     Wave-function probability distribution |psi_v|^2: ", pt.diag(rho).real)
   print("     Probe probability distribution:")
   print("     random number: r = ", "{:2.2f}".format(r), " | cumulative probability = ", "{:2.2f}".format(p_c))
   print("     wave function collapsed to bit string -> |v=" + str(v) + "> = ",bit_string.real)
   print("     which corresponds to the basis vector:", psi_measured.real)
  return psi_measured, bit_string.real.to(pt.int).tolist()

In [6]:
def prepare_QST_classical_shadows_(rho_target, N_shadows, verbose):

    rho_shadow_global_measurement = pt.zeros((D,D)) + 0*1j
  
    for i in tqdm(range(0,N_shadows)):   
      U_i = np.random.choice([1,2,3], L, replace=True)
    
      rho_rotated = rho_target.clone()
      for j_spin in range(1,L+1):
        if(U_i[j_spin-1] == 1):
            U_i_j = H[j_spin]
    
        if(U_i[j_spin-1] == 2):    
            U_i_j = H[j_spin]@S_dagger[j_spin]
     
        if(U_i[j_spin-1] == 3):
            U_i_j = Z[j_spin]      
    
        rho_rotated = U_i_j@rho_rotated@U_i_j.conj().T

      rho_after_measurement, b_i_global_measurement  = prepare_measurement_global_(rho_rotated, verbose)
      s_i_global_measurement = pt.tensor([1])
     
      for j_spin in range(1,L+1):
        if(U_i[j_spin-1] == 1):
            V_i_j = hadamard
     
        if(U_i[j_spin-1] == 2):    
            V_i_j = hadamard@s_dagger
       
        if(U_i[j_spin-1] == 3):
            V_i_j = sigma_z
     
        # Inverse channel after global measurement
    
        if(b_i_global_measurement[j_spin-1] == 1):
          rho_j_global_measurement = M_u
    
        if(b_i_global_measurement[j_spin-1] == -1):
          rho_j_global_measurement = M_d
    
    
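        # Append the factor of spin j on the right, so that spin 1 stays the leftmost
        # tensor factor, consistent with the ordering used in get_chain_operator
        s_i_global_measurement =  pt.kron(s_i_global_measurement, 3*V_i_j.conj().T@rho_j_global_measurement@V_i_j - id_local)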
    
      rho_shadow_global_measurement = rho_shadow_global_measurement + s_i_global_measurement
    
    rho_QST  = rho_shadow_global_measurement/N_shadows
    return rho_QST

As an example, let us consider the reconstruction of a quantum system in an $L$-qubit GHZ state, i.e.

\begin{equation}
|{\rm GHZ}\rangle = \frac{|0\ldots0\rangle + |1\ldots1\rangle}{\sqrt{2}}
\end{equation}

In [21]:
psi_target =  pt.zeros(D) + 0*1j # GHZ state
psi_target[0] = 1
psi_target[-1] = 1
norm = pt.sum(pt.abs(psi_target)**2)
psi_target = psi_target/pt.sqrt(norm)
rho_target = pt.outer(psi_target, psi_target.conj())  # |psi><psi|

In [29]:
rho_qst = prepare_QST_classical_shadows_(rho_target, 200000, verbose = False)

100%|█████████████████████████████████| 200000/200000 [00:40<00:00, 4911.88it/s]


In [30]:
print_density_matrix(rho_target,"Target density matrix:",True)
print_density_matrix(rho_qst,"Reconstructed density matrix:",True)

Target density matrix:
Norm:   Tr[rho] = 1.000
Purity: Tr[rho^2] = 1.000
Eigenvalues :
 lambda_000  = 1.000
 lambda_001  = -0.000
 lambda_002  = 0.000
 lambda_003  = 0.000
 lambda_004  = 0.000
 lambda_005  = 0.000
 lambda_006  = 0.000
 lambda_007  = 0.000

Real part:
           |000>      |001>      |002>      |003>      |004>      |005>      |006>      |007>
|000>      0.500      0.000      0.000      0.000      0.000      0.000      0.000      0.500
|001>      0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000
|002>      0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000
|003>      0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000
|004>      0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000
|005>      0.000      0.000      0.000      0.000      0.000      0.000      0.000      0.000
|006>      0.000      0.000      0.000      0.000      0.000      0.000   

In [31]:
trace_distance = get_trace_distance(rho_target, rho_qst)
print("Trace distance between target and reconstructed state: {:2.2f}".format(trace_distance))

Trace distance between target and reconstructed state: 0.03


As we can see, the trace distance between the reconstructed and target density matrices is close to $0$. Since the reconstruction is an average over independent classical shadows, its statistical error is expected to shrink as the number of shadows grows.
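
The dependence of the reconstruction quality on the number of shadows can be explored with a small single-spin Monte Carlo experiment. The sketch below uses NumPy instead of PyTorch for brevity, a fixed seed for reproducibility, and the state $|+\rangle\langle +|$ as an arbitrary test case; the function names are illustrative and not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility

I2 = np.eye(2, dtype=complex)
sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)
hadamard = (sigma_x + sigma_z) / np.sqrt(2)
s_dagger = np.array([[1, 0], [0, -1j]], dtype=complex)
unitaries = [hadamard, hadamard @ s_dagger, sigma_z]  # X, Y, Z settings
projectors = [np.diag([1, 0]).astype(complex), np.diag([0, 1]).astype(complex)]


def single_spin_shadow_estimate(rho, n_shadows):
    """Average n_shadows randomly sampled single-spin classical shadows of rho."""
    estimate = np.zeros((2, 2), dtype=complex)
    for _ in range(n_shadows):
        u = unitaries[rng.integers(3)]
        rho_rotated = u @ rho @ u.conj().T
        # Sample the measurement outcome from the Born probabilities.
        outcome = 0 if rng.random() < rho_rotated[0, 0].real else 1
        estimate += 3 * u.conj().T @ projectors[outcome] @ u - I2
    return estimate / n_shadows


def trace_distance(a, b):
    """Half the sum of absolute eigenvalues of the Hermitian difference a - b."""
    return 0.5 * np.sum(np.abs(np.linalg.eigvalsh(a - b)))


rho = 0.5 * (I2 + sigma_x)  # |+><+|
for n_shadows in (100, 1000, 10000):
    print(n_shadows, trace_distance(rho, single_spin_shadow_estimate(rho, n_shadows)))
```

Individual runs fluctuate, so the printed distances need not decrease monotonically, but on average they should fall as more shadows are collected.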