# TDVP numerical test \#1

The goal is to use the time-dependent variational principle (TDVP) to simulate the time evolution under a simple Hamiltonian when the state space is restricted to some manifold. A further goal is to compare this method with directly solving the Schrödinger equation (SE) and projecting the result onto the manifold.

The first example uses $\mathcal{M}= S^2$ and $H=\sigma_x$, which generates an X-rotation of a single-qubit system. Note that the Bloch sphere is the restriction of the full two-dimensional $\mathcal{H}$ to a sphere (normalized states, up to a global phase).
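As a reference point for the comparison with the SE, the exact evolution under $H=\sigma_x$ can be written in closed form, since $e^{-i\sigma_x t} = \cos(t)\,\mathbb{1} - i\sin(t)\,\sigma_x$. The sketch below assumes $\hbar = 1$ and an initial state $|0\rangle$; neither is fixed by the text above, and the helper names `evolve_exact` and `bloch_angles` are hypothetical.

```python
import numpy as np

sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)

def evolve_exact(psi0, t):
    """Apply exp(-i sigma_x t) via the closed form cos(t) I - i sin(t) sigma_x (hbar = 1 assumed)."""
    U = np.cos(t) * np.eye(2) - 1j * np.sin(t) * sigma_x
    return U @ psi0

def bloch_angles(state):
    """Return (theta, phi) such that state ~ cos(theta/2)|0> + exp(i phi) sin(theta/2)|1>."""
    a, b = state
    theta = 2 * np.arctan2(abs(b), abs(a))
    # Relative phase between the two amplitudes, mapped into [0, 2 pi).
    phi = (np.angle(b) - np.angle(a)) % (2 * np.pi)
    return theta, phi

psi0 = np.array([1, 0], dtype=complex)  # assumed initial state |0>
theta, phi = bloch_angles(evolve_exact(psi0, 0.3))
```

For this initial state the polar angle grows as $\theta = 2t$ while $\phi$ stays at $3\pi/2$, i.e. a rotation about the x-axis. A TDVP trajectory on $S^2$ should reproduce this, because here the manifold already contains every physical state.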

In [23]:
from math import pi
import numpy as np
import scipy as sp

import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt

import qutip as qt

In [4]:
H = qt.sigmax().full()

In [31]:
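def psi(params: np.ndarray):
    """Bloch-sphere state cos(theta/2)|0> + exp(i*phi) sin(theta/2)|1> as a column vector."""
    theta, phi = params
    ket0 = qt.basis(2, 0).full()
    ket1 = qt.basis(2, 1).full()
    # Polar angle in [0, pi], azimuthal angle in [0, 2*pi).
    assert 0 <= theta <= np.pi, f"theta={theta} out of bounds."
    assert 0 <= phi < 2*np.pi, f"phi={phi} out of bounds."
    return np.cos(theta/2)*ket0 + np.exp(1j*phi)*np.sin(theta/2)*ket1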

In [32]:
psi([pi/2,0])

array([[0.70710678+0.j],
       [0.70710678+0.j]])

In [33]:
jacpsi = jax.jacfwd(psi)

In [34]:
z=np.array([1.1,2.2])

In [36]:
jacpsi(z)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.550000011920929, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(0.55, dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
    val = DeviceArray([0.5, 0. ], dtype=float32)
    batch_dim = 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
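The error above most likely comes from calling NumPy functions (`np.cos`, `np.exp`, `np.sin`) on JAX tracers: `jax.jacfwd` can only differentiate through `jax.numpy` operations. A minimal sketch of a traceable variant follows. It assumes flat basis vectors of shape `(2,)` instead of the `(2,1)` arrays from QuTiP and drops the bounds assertions; the name `psi_jax` is hypothetical.

```python
import jax
import jax.numpy as jnp

def psi_jax(params):
    """Bloch-sphere state built only from jax.numpy operations."""
    theta, phi = params
    # Assumption: flat basis vectors instead of qt.basis(...).full() columns.
    ket0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    ket1 = jnp.array([0.0, 1.0], dtype=jnp.complex64)
    return jnp.cos(theta / 2) * ket0 + jnp.exp(1j * phi) * jnp.sin(theta / 2) * ket1

jac_psi = jax.jacfwd(psi_jax)
z = jnp.array([1.1, 2.2])
J = jac_psi(z)  # shape (2, 2): column 0 is d psi / d theta, column 1 is d psi / d phi
```

The columns of `J` are the tangent vectors $\partial_\theta|\psi\rangle$ and $\partial_\phi|\psi\rangle$.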

In [41]:
def f(X):
    x, y = X
    return np.array([x**2 + y, 5*np.sin(y)])

In [43]:
jf = jax.jacfwd(f)
X=np.array([1.0,1.0])
jf(X)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(1.0, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(1., dtype=float32)
  tangent = Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
    val = DeviceArray([0., 1.], dtype=float32)
    batch_dim = 0
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
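The same failure shows up in this smaller test function, which suggests the problem lies in mixing NumPy calls with JAX tracers rather than in `psi` itself. A sketch of the toy function written with `jax.numpy` throughout (the name `f_jax` is hypothetical):

```python
import jax
import jax.numpy as jnp

def f_jax(X):
    """Same toy map as f, but traceable: (x, y) -> (x**2 + y, 5 sin(y))."""
    x, y = X
    return jnp.array([x**2 + y, 5 * jnp.sin(y)])

jf = jax.jacfwd(f_jax)
X = jnp.array([1.0, 1.0])
J = jf(X)  # analytic Jacobian: [[2x, 1], [0, 5 cos(y)]]
```

At `X = (1, 1)` the analytic Jacobian is $\begin{pmatrix} 2 & 1 \\ 0 & 5\cos(1) \end{pmatrix}$, which gives a quick check of the result.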