# The Schmidt Decomposition
---

### Definition

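Any pure bipartite state can be expanded in orthonormal bases $\{|\psi_i\rangle_A\}_{i=0}^{N-1}$ and $\{|\phi_j\rangle_B\}_{j=0}^{M-1}$ of subsystems $A$ and $B$, whose Hilbert spaces have dimensions $N$ and $M$: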

$$ 
|\zeta \rangle_{AB} = \sum_{i=0}^{N-1}\sum_{j=0}^{M-1} c_{ij} |\psi_i \rangle_{A} \otimes |\phi_j \rangle_{B} 
$$

For example:

$$ |\zeta \rangle_{AB} = \frac{1}{2} |0\rangle_A |0\rangle_B - 
                         \frac{1}{2} |0\rangle_A |1\rangle_B + 
                         \frac{1}{2} |1\rangle_A |0\rangle_B - 
                         \frac{1}{2} |1\rangle_A |1\rangle_B  $$

The Schmidt decomposition of such a state is a sum over a single index $k$:

$$ |\zeta \rangle_{AB} = \sum_{k=0}^{d-1} \lambda_{k} |u_k \rangle_{A} \otimes |v_k \rangle_{B}, $$

- $|u_k \rangle_{A}$ and $|v_k \rangle_{B}$ are the Schmidt vectors: orthonormal sets of states in the Hilbert spaces of $A$ and $B$, respectively.

- $\lambda_k$ are the Schmidt coefficients: $\lambda_k \in \mathbb{R}_{\geq 0}$, with $\sum_k \lambda_k^2 = 1$ for a normalized state.

- $d$ is the total number of Schmidt coefficients: $d = \min(N, M)$.

For the example above, the Schmidt decomposition is:

$$ 
|\zeta \rangle_{AB} = \frac{1}{\sqrt{2}} \bigg( |0\rangle_A + |1\rangle_A \bigg) \otimes 
                         \frac{1}{\sqrt{2}} \bigg( |0\rangle_B - |1\rangle_B \bigg) 
                         = 1 |+\rangle_A |-\rangle_B + 0 |-\rangle_A |+\rangle_B
$$
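The decomposition can be computed from a singular value decomposition (SVD) of the coefficient matrix $c_{ij}$: writing $c = U \Sigma V^\dagger$, the singular values are the $\lambda_k$, the columns of $U$ give $|u_k\rangle_A$ and the rows of $V^\dagger$ give $|v_k\rangle_B$. The NumPy sketch below assumes the amplitude of $|i\rangle_A|j\rangle_B$ sits at index `i*M + j` (the textbook ordering with $A$ as the left tensor factor, not Qiskit's little-endian one) and uses a hypothetical helper `schmidt_svd`. For the example state it returns $\lambda = (1, 0)$.

```python
import numpy as np


def schmidt_svd(psi, N, M):
    """Hypothetical helper: Schmidt decomposition of a pure state on C^N (x) C^M.

    Assumes psi[i*M + j] is the amplitude of |i>_A |j>_B.
    Returns (lambdas, us, vs): lambdas in descending order (as NumPy's SVD
    returns them), us[k] = |u_k>_A and vs[k] = |v_k>_B.
    """
    c = np.asarray(psi, dtype=complex).reshape(N, M)   # coefficient matrix c_ij
    U, s, Vh = np.linalg.svd(c, full_matrices=False)   # c = U diag(s) Vh
    us = [U[:, k] for k in range(len(s))]              # columns of U  -> |u_k>_A
    vs = [Vh[k, :] for k in range(len(s))]             # rows of Vh    -> |v_k>_B
    return s, us, vs


# Example state: (|00> - |01> + |10> - |11>) / 2 = |+>|->
zeta = np.array([1, -1, 1, -1]) / 2
lambdas, us, vs = schmidt_svd(zeta, 2, 2)
print(np.round(lambdas, 12))   # d = min(N, M) = 2 coefficients

# Rebuild the state as sum_k lambda_k |u_k> (x) |v_k>
rebuilt = sum(lam * np.kron(u, v) for lam, u, v in zip(lambdas, us, vs))
print(np.allclose(rebuilt, zeta))
```

The signs (or complex phases) of `us[0]` and `vs[0]` depend on the underlying LAPACK routine and are not unique; this is the issue discussed in the Issues section.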

### Utility

The Schmidt decomposition makes it easy to assess whether a state is separable or entangled.

This is done using the Schmidt number (or Schmidt rank) $r_s$: the number of nonzero Schmidt coefficients.

- $ 1 \leq r_s \leq d $
- If $r_s = 1$, the state is separable
- If $r_s > 1 $, the state is entangled

For example:
$$ 
\begin{aligned}
|\zeta \rangle_{AB} &= \frac{1}{2} |0\rangle_A |0\rangle_B - 
                         \frac{1}{2} |0\rangle_A |1\rangle_B + 
                         \frac{1}{2} |1\rangle_A |0\rangle_B - 
                         \frac{1}{2} |1\rangle_A |1\rangle_B    
\\
\\
&= 1 |+\rangle_A |-\rangle_B + 0 |-\rangle_A |+\rangle_B 
\end{aligned}                         
$$ 

Here, $r_s = 1$, so the state is separable.

$$
\begin{aligned}
|\chi \rangle_{AB} &= \frac{1}{2} |0\rangle_A |0\rangle_B \color{orange}{+} 
                         \frac{1}{2} |0\rangle_A |1\rangle_B + 
                         \frac{1}{2} |1\rangle_A |0\rangle_B - 
                         \frac{1}{2} |1\rangle_A |1\rangle_B    
\\
\\
&= \frac{1}{\sqrt{2}} |0\rangle_A |+\rangle_B + \frac{1}{\sqrt{2}} |1\rangle_A |-\rangle_B 
\end{aligned}                         
$$ 

Here, $r_s = 2$, so the state is entangled.
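Numerically, the Schmidt rank is the number of singular values of $c_{ij}$ above a small tolerance, since floating-point SVDs rarely return an exact zero. A minimal sketch with a hypothetical helper `schmidt_rank`, assuming the amplitude of $|i\rangle_A|j\rangle_B$ sits at index `i*M + j` and an arbitrary cutoff `tol=1e-10`, classifies the two examples above:

```python
import numpy as np


def schmidt_rank(psi, N, M, tol=1e-10):
    """Hypothetical helper: number of Schmidt coefficients above tol.

    Assumes psi[i*M + j] is the amplitude of |i>_A |j>_B.
    """
    c = np.asarray(psi, dtype=complex).reshape(N, M)
    s = np.linalg.svd(c, compute_uv=False)   # singular values only
    return int(np.sum(s > tol))


zeta = np.array([1, -1, 1, -1]) / 2   # |+>_A |->_B
chi = np.array([1, 1, 1, -1]) / 2     # (|0>|+> + |1>|->) / sqrt(2)

for name, psi in [("zeta", zeta), ("chi", chi)]:
    r = schmidt_rank(psi, 2, 2)
    print(name, r, "separable" if r == 1 else "entangled")
# prints: zeta 1 separable
#         chi 2 entangled
```

The tolerance is a design choice: too small and rounding noise counts as an extra Schmidt term, too large and weakly entangled states are reported as separable.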

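Furthermore, several entanglement measures of pure bipartite states are functions of the Schmidt coefficients alone, e.g., in terms of the reduced state $\rho_A = \text{Tr}_B\left(|\zeta\rangle\langle\zeta|\right)$:
- Subsystem purity: $\gamma = \text{Tr}\left(\rho_A^2\right) = \sum_k \lambda_k^4$
- von Neumann entropy: $S = -\text{Tr}\left(\rho_A \log \rho_A\right) = -\sum_k \lambda_k^2 \log\left(\lambda_k^2\right)$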
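These formulas hold because $\rho_A = \sum_k \lambda_k^2 |u_k\rangle\langle u_k|$, so the eigenvalues of $\rho_A$ are exactly $\lambda_k^2$. The sketch below computes both measures from the Schmidt coefficients with a hypothetical helper `entanglement_measures`. It assumes $\log_2$ (entropy in bits; the formula above does not fix the base) and drops $\lambda_k = 0$ terms, following the convention $0 \log 0 = 0$. As a cross-check, the purity is compared with $\text{Tr}(\rho_A^2)$ computed directly from $\rho_A = c\,c^\dagger$, where $c$ is the coefficient matrix with rows indexed by $A$.

```python
import numpy as np


def entanglement_measures(lambdas):
    """Hypothetical helper: purity and von Neumann entropy (bits) from Schmidt coefficients."""
    p = np.asarray(lambdas, dtype=float) ** 2    # eigenvalues of rho_A
    purity = float(np.sum(p ** 2))               # gamma = sum_k lambda_k^4
    nz = p[p > 1e-12]                            # drop zeros: 0 log 0 = 0
    entropy = float(-np.sum(nz * np.log2(nz)))   # S = -sum_k lambda_k^2 log2(lambda_k^2)
    return purity, entropy


# Bell state (|00> + |11>)/sqrt(2): coefficient matrix c_ij and its Schmidt coefficients
c = np.array([[1, 0], [0, 1]]) / np.sqrt(2)
lambdas = np.linalg.svd(c, compute_uv=False)
purity, entropy = entanglement_measures(lambdas)
print(purity, entropy)

# Cross-check: purity equals Tr(rho_A^2) with rho_A = c c^dagger
rho_A = c @ c.conj().T
print(np.isclose(purity, np.trace(rho_A @ rho_A).real))
```

For the Bell state this gives $\gamma = 1/2$ and $S = 1$ bit, the maximum for two qubits; a product state ($\lambda = (1, 0)$) gives $\gamma = 1$ and $S = 0$.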

In [19]:
import numpy as np
from IPython.display import Latex, display
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, schmidt_decomposition

In [20]:
# Bell state (|00> + |11>)/√2 (a different state from the separable |ζ> above)
ζ = Statevector([1/np.sqrt(2),0,0,1/np.sqrt(2)])
ζ.draw('latex',prefix='|\\zeta\\rangle = ')


In [21]:
# list of (s, u, v) tuples: Schmidt coefficient and the two Schmidt vectors
ζ_sd = schmidt_decomposition(ζ,[0])
print(ζ_sd)

[(0.7071067811865475, Statevector([1.+0.j, 0.+0.j],
            dims=(2,)), Statevector([1.+0.j, 0.+0.j],
            dims=(2,))), (0.7071067811865475, Statevector([0.+0.j, 1.+0.j],
            dims=(2,)), Statevector([0.+0.j, 1.+0.j],
            dims=(2,)))]


In [22]:
for i, (s,u,v) in enumerate(ζ_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {s}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: two Schmidt terms, each displayed as $\lambda_i$, $|u_i\rangle$ and $|v_i\rangle$.)*

In [23]:
# three-qubit W state (|001> + |010> + |100>)/√3
w = Statevector([0,1/np.sqrt(3),1/np.sqrt(3),0,1/np.sqrt(3),0,0,0])
w.draw('latex',prefix='|w\\rangle = ')


In [24]:
# bipartition: qubits {0, 1} versus qubit 2
w_sd = schmidt_decomposition(w,[0,1])

In [25]:
for i, (s,u,v) in enumerate(w_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {s}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: two Schmidt terms, each displayed as $\lambda_i$, $|u_i\rangle$ and $|v_i\rangle$.)*

In [None]:
# Rebuild the state from its decomposition: sum_k λ_k |u_k> ⊗ |v_k>
Statevector(sum(s * np.kron(u, v) for s, u, v in w_sd))

`schmidt_decomposition` works not only for qubit systems but for any bipartite qudit system.

For example, consider the two-qutrit state:

$$ 
|\chi\rangle_{AB} = \frac{1}{\sqrt{3}} \bigg (|0_A0_B\rangle + |1_A1_B\rangle + |2_A2_B\rangle \bigg ) 
$$

In [27]:
# (|00> + |11> + |22>)/√3 on two qutrits
χ = Statevector(np.array([1,0,0,0,1,0,0,0,1])/np.sqrt(3), dims=(3,3))
χ.draw('latex', prefix='|\\chi\\rangle = ',max_size=10)


In [28]:
χ_sd = schmidt_decomposition(χ,[0])
for i, (s,u,v) in enumerate(χ_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {s}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: three Schmidt terms, each displayed as $\lambda_i$, $|u_i\rangle$ and $|v_i\rangle$.)*
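Nothing in the SVD construction is specific to qubits: for an $N \times M$ coefficient matrix, `np.linalg.svd(..., compute_uv=False)` returns exactly $d = \min(N, M)$ singular values. The NumPy sketch below checks the qutrit state above (its coefficient matrix is $I/\sqrt{3}$ regardless of index ordering) and an arbitrary qubit-qutrit ($2 \times 3$) state; the random seed and the `i*M + j` amplitude ordering are assumptions made for illustration.

```python
import numpy as np

# Two-qutrit state (|00> + |11> + |22>)/sqrt(3): its 3x3 coefficient matrix is I/sqrt(3)
chi = np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]) / np.sqrt(3)
s_chi = np.linalg.svd(chi.reshape(3, 3), compute_uv=False)
print(s_chi)   # three equal coefficients 1/sqrt(3): r_s = d = 3

# A random normalized qubit-qutrit state: N = 2, M = 3, so d = min(2, 3) = 2
rng = np.random.default_rng(seed=0)
psi = rng.normal(size=6) + 1j * rng.normal(size=6)
psi /= np.linalg.norm(psi)
s_psi = np.linalg.svd(psi.reshape(2, 3), compute_uv=False)
print(len(s_psi), np.sum(s_psi ** 2))   # 2 coefficients whose squares sum to 1
```

Equal Schmidt coefficients, as for $|\chi\rangle$, mean the state is maximally entangled for its dimension.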

In [29]:
dims = (2,3,4)
# amplitudes 0, 1, ..., 23 (not normalized, so the λ_k will not satisfy Σ λ_k² = 1)
k = Statevector(np.arange(np.prod(dims)),dims=dims)
k.draw('latex',max_size=24)


In [31]:
# bipartition: subsystems 0 and 2 (dims 2 and 4) versus subsystem 1 (dim 3)
k_sd = schmidt_decomposition(k,[0,2])
for i, (s,u,v) in enumerate(k_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {s}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: two Schmidt terms, each displayed as $\lambda_i$, $|u_i\rangle$ and $|v_i\rangle$.)*

In [32]:
Statevector(sum([suv[0]*np.kron(suv[1],suv[2]) for suv in k_sd]),dims=dims).draw('latex',max_size=24)

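For more than two subsystems, a bipartition such as $\{0, 2\}$ versus $\{1\}$ amounts to regrouping tensor axes before reshaping into a matrix. The NumPy sketch below shows the idea with `np.transpose` and a hypothetical helper `bipartite_svd`. It assumes axis $i$ of `psi.reshape(dims)` is subsystem $i$ (big-endian, unlike Qiskit's little-endian ordering, so its vectors need not match Qiskit's output entry by entry). With amplitudes $0, 1, \dots, 23$ (normalized here) each amplitude is an affine function of the subsystem indices, so the $8 \times 3$ coefficient matrix has rank 2.

```python
import numpy as np


def bipartite_svd(psi, dims, part_a):
    """Hypothetical helper: SVD of psi across the cut part_a | rest.

    Assumes axis i of psi.reshape(dims) is subsystem i (big-endian ordering).
    Returns (s, U, Vh, perm) so that the grouped matrix is U @ diag(s) @ Vh.
    """
    part_b = [i for i in range(len(dims)) if i not in part_a]
    perm = list(part_a) + part_b                    # bring A's axes to the front
    tensor = np.asarray(psi).reshape(dims).transpose(perm)
    n_a = int(np.prod([dims[i] for i in part_a]))   # dimension of subsystem A
    c = tensor.reshape(n_a, -1)                     # rows: A, columns: B
    U, s, Vh = np.linalg.svd(c, full_matrices=False)
    return s, U, Vh, perm


dims = (2, 3, 4)
amps = np.arange(np.prod(dims), dtype=float)
psi = amps / np.linalg.norm(amps)

s, U, Vh, perm = bipartite_svd(psi, dims, part_a=[0, 2])
print(np.round(s, 10))   # min(8, 3) = 3 values, only 2 of them nonzero

# Undo the regrouping: rebuild the matrix, restore the tensor shape, invert the permutation
c = (U * s) @ Vh
tensor = c.reshape([dims[i] for i in perm]).transpose(np.argsort(perm))
print(np.allclose(tensor.reshape(-1), psi))
```

`np.argsort(perm)` is the inverse permutation, which is what restores the original subsystem order after the decomposition.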

### Issues

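To compute the Schmidt decomposition, Qiskit's `schmidt_decomposition` reshapes the amplitudes into the coefficient matrix $c_{ij}$ and performs its singular value decomposition (SVD); the singular values are the Schmidt coefficients.

This is done using `numpy.linalg.svd`.

A well-known subtlety of the SVD is that the singular vectors are only defined up to a phase: replacing $|u_k\rangle \to e^{i\theta}|u_k\rangle$ and $|v_k\rangle \to e^{-i\theta}|v_k\rangle$ leaves each term $\lambda_k |u_k\rangle|v_k\rangle$ unchanged. For real amplitudes, this means the sign of each pair of Schmidt vectors is arbitrary.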

For example, the state:

$$ |\xi \rangle_{AB} = \frac{1}{2} |0\rangle_A |0\rangle_B + 
                         \frac{1}{2} |0\rangle_A |1\rangle_B + 
                         \frac{1}{2} |1\rangle_A |0\rangle_B + 
                         \frac{1}{2} |1\rangle_A |1\rangle_B  
$$
has two equally valid Schmidt decompositions:

$$
\begin{aligned}
|\xi  \rangle_{AB} &= |+\rangle_A |+\rangle_B
\\
\\
&\quad \text{or}
\\
\\
|\xi \rangle_{AB} &= \big (\text{-} |+\rangle_A \big) \big(\text{-} | +\rangle_B \big)
\end{aligned}
$$


In the second case, each Schmidt vector carries an additional global phase of $-1$.

In quantum computing, this is not a problem: the two phases cancel in the product, and a global phase on a state is unphysical anyway.

However, not knowing which phase one will get can cause problems in code that checks whether statevectors are equal, since a direct comparison of amplitudes treats $|+\rangle$ and $-|+\rangle$ as different.
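A comparison that ignores global phase avoids this problem; in Qiskit, `Statevector.equiv` compares two states up to global phase, whereas `==` compares amplitudes directly. A NumPy-only sketch of the same idea, with a hypothetical helper name and tolerance, estimates the relative phase from the largest entry and then compares element-wise:

```python
import numpy as np


def equal_up_to_global_phase(a, b, atol=1e-10):
    """Hypothetical helper: True if b == e^{i theta} a for some real theta."""
    a = np.asarray(a, dtype=complex)
    b = np.asarray(b, dtype=complex)
    if a.shape != b.shape:
        return False
    idx = np.argmax(np.abs(a))               # largest entry of a, far from zero
    if abs(a[idx]) < atol:
        return bool(np.allclose(b, 0, atol=atol))   # a is (numerically) zero
    phase = b[idx] / a[idx]
    if not np.isclose(abs(phase), 1.0, atol=atol):
        return False                          # magnitudes differ, not just a phase
    return bool(np.allclose(phase * a, b, atol=atol))


plus = np.array([1, 1]) / np.sqrt(2)
print(np.allclose(plus, -plus))                    # False: element-wise check sees the sign
print(equal_up_to_global_phase(plus, -plus))       # True: -|+> is |+> up to a global phase
print(equal_up_to_global_phase(plus, 1j * plus))   # True
print(equal_up_to_global_phase(plus, np.array([1, -1]) / np.sqrt(2)))  # False: |-> is not |+>
```

Using the largest entry to estimate the phase avoids dividing by an amplitude that is zero or dominated by rounding noise.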

In [33]:
ξ = Statevector.from_label('++')
ξ.draw('latex',prefix='|\\xi\\rangle = ')


In [34]:
ξ_sd = schmidt_decomposition(ξ,[0])
for i, (s,u,v) in enumerate(ξ_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {np.round(s,6)}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: one Schmidt term, displayed as $\lambda_0$, $|u_0\rangle$ and $|v_0\rangle$.)*
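One way to make the output reproducible is to fix the phase of each Schmidt pair by convention after the SVD. The sketch below uses a hypothetical rule: rotate each $|u_k\rangle$ so that its first entry with non-negligible magnitude is real and positive, and apply the compensating phase to $|v_k\rangle$ so that each term $\lambda_k|u_k\rangle|v_k\rangle$ is unchanged. This is an illustrative convention, not necessarily what Qiskit or the `schmidt_decomposition_local` module does. It also cannot remove all freedom: when Schmidt coefficients are degenerate (as for the Bell state), the vectors are only defined up to a unitary rotation within the degenerate subspace.

```python
import numpy as np


def fix_phases(U, Vh, tol=1e-12):
    """Hypothetical convention: make the first non-negligible entry of each
    column of U real and positive, compensating in the matching row of Vh."""
    U = np.array(U, dtype=complex)     # copies, so the inputs are not modified
    Vh = np.array(Vh, dtype=complex)
    for k in range(U.shape[1]):
        idx = np.flatnonzero(np.abs(U[:, k]) > tol)[0]   # first significant entry
        phase = U[idx, k] / abs(U[idx, k])               # unit-modulus phase
        U[:, k] *= np.conj(phase)    # |u_k> -> e^{-i theta} |u_k>
        Vh[k, :] *= phase            # |v_k> -> e^{+i theta} |v_k>
    return U, Vh


# |xi> = |+>|+>: coefficient matrix c_ij = 1/2 for all i, j
c = np.full((2, 2), 0.5)
U, s, Vh = np.linalg.svd(c)

# (U, Vh) and (-U, -Vh) are both valid SVDs; after fixing phases they agree
U1, Vh1 = fix_phases(U, Vh)
U2, Vh2 = fix_phases(-U, -Vh)
print(np.allclose(U1, U2), np.allclose(Vh1, Vh2))
print(np.allclose((U1 * s) @ Vh1, c))   # the decomposition still reproduces c
```

Choosing the first significant entry rather than the largest one avoids ambiguity when two entries have (nearly) equal magnitude, as for $|+\rangle$.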

In [35]:
from schmidt_decomposition_local import schmidt_decomposition as sd_loc

In [36]:
ξ_sd = sd_loc(ξ,[0])
for i, (s,u,v) in enumerate(ξ_sd):
    print('---')
    display(Latex(f"$$\\lambda_{i} = {np.round(s,6)}$$"))
    display(u.draw('latex',prefix=f'|u_{i}\\rangle = '))
    display(v.draw('latex',prefix=f'|v_{i}\\rangle = '))

*(Output, rendered as LaTeX: one Schmidt term, displayed as $\lambda_0$, $|u_0\rangle$ and $|v_0\rangle$.)*