<a href="https://colab.research.google.com/github/RodrigoAVargasHdz/JAX_projects/blob/main/GD_harmonic_oscillator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install --upgrade jax jaxlib

In [17]:
import jax.numpy as jnp
import jax
from jax import grad, value_and_grad
from jax import vmap
from jax import random
from jax.experimental import stax
from jax.experimental import optimizers

import matplotlib.pyplot as plt


**Kinetic energy in the DVR basis**

In [4]:
def kinetic_energy_jax(dx, m):
    '''
    Kinetic-energy matrix on a uniform position grid (DVR basis).

    dx --> delta x in the grid for position
    m --> size of the grid

    Returns an (m, m) symmetric matrix T with
        T[i, i] = (pi/dx)**2 / 6
        T[i, j] = (-1)**(i-j) / (dx**2 * (i-j)**2)    for i != j
    (presumably in units where hbar = mass = 1 -- confirm against the
    potential used alongside it).
    '''
    k2 = (jnp.pi / dx) ** 2

    # Signed index separation (column - row) of every matrix element, built
    # by broadcasting rather than by summing 2m-1 shifted identity matrices.
    idx = jnp.arange(m)
    sep = idx[None, :] - idx[:, None]
    on_diag = sep == 0

    # (-1)**sep, taken from the parity of the integer separation.
    sign = jnp.where(sep % 2 == 0, 1., -1.)

    # Put a harmless 1 on the diagonal so the division is always finite;
    # those entries are discarded by the `where` below, so no large
    # intermediate values (or inf * 0) are ever produced.
    sep_sq = jnp.where(on_diag, 1, sep) ** 2
    off_diag = jnp.where(on_diag, 0., (2. * k2 / jnp.pi**2) * (sign / sep_sq))

    T = (k2 / 3.) * jnp.eye(m) + off_diag
    return T / 2.

**Potential energy function**

In [5]:
# Harmonic oscillator
def pes_function_ho(param, x):
    '''
    Quadratic potential energy evaluated at position x.

    param --> parameter controlling the curvature of the well
    x --> position

    Returns mass * sqrt(param / mass) / 2 * x**2, with the mass fixed to 1.

    NOTE(review): the value scales with sqrt(param), whereas a harmonic
    well is usually written 0.5 * mass * omega**2 * x**2 -- confirm which
    physical quantity `param` is meant to represent.
    '''
    mass = 1.0
    prefactor = mass * jnp.sqrt(param / mass) / 2.
    return prefactor * x**2

**Target spectrum**


In [None]:
# Uniform position grid: nx points on [-25, 25]
nx = 250
x = jnp.linspace(-25., 25., num=nx)

# Kinetic energy matrix for the spacing between neighbouring grid points
T = kinetic_energy_jax(x[1] - x[0], nx)
# Potential mapped over positions: `param` is shared, `x` is batched
v_vect_ho = vmap(pes_function_ho, in_axes=(None, 0))

# HO freq (omega) of the target system, stored as a length-1 array
w = jnp.ones((1,))

# Potential on the grid; the trailing axis comes from the length-1 parameter
v = v_vect_ho(w, x)
# Diagonal potential-energy matrix
V = jnp.diag(v[:, 0])

# Target Hamiltonian and its eigenvalues (e0) / eigenvectors (ev0)
H0 = T + V
e0, ev0 = jnp.linalg.eigh(H0)

In [None]:
# Target eigenvalues plotted against their index
plt.plot(jnp.arange(len(e0)), e0)
plt.xlabel('Eigen energy [n]')
plt.ylabel('Energy')

**Optimization**

Let's optimize $\omega$ with respect to some target spectrum.

\begin{equation}
\omega^* = \arg\min_{\omega} \;\; {\cal L}
\end{equation}

Error function,

\begin{equation}
{\cal L} (\omega) = \sum_i^N \Big ( \epsilon_i(\omega) - \hat{\epsilon}_i\Big )^{2}
\end{equation}



**Gradient descent (GD)**

Quick recap of GD,

\begin{equation}
\omega_{t} = \omega_{t-1} - \eta \left. \frac{\partial {\cal L} }{\partial \omega} \right|_{\omega = \omega_{t-1}}
\end{equation}

How do we compute $\frac{\partial {\cal L} }{\partial \omega} $?
1. Finite difference
2. Closed form (only for some systems)
3. Automatic differentiation

In [21]:
	def loss(w):
		'''
		Squared deviation of the mean level spacing from 1.

		w --> parameter handed to fnn_vect to build the potential on the grid

		Reads the module-level names fnn_vect, x and T. The Hamiltonian
		T + diag(potential) is diagonalized and only the lowest 50
		eigenvalues enter the result. The target eigenvalues e0 are not
		used here: the loss depends only on the average gap between
		consecutive eigenvalues, not on a sum of squared differences.
		'''
		n_low = 50  # only the first n_low eigenstates
		potential = fnn_vect(w, x)
		hamiltonian = T + jnp.diag(potential[:, 0])
		levels = jnp.linalg.eigh(hamiltonian)[0]
		mean_gap = jnp.mean(jnp.diff(levels[:n_low]))
		return (mean_gap - 1.)**2



In [None]:
fnn_vect = vmap(pes_function_ho,(None,0)) # vectorize function

# Optimization
m = 1  # number of trainable parameters (original note: "NN (m=31 HO)")
key = random.PRNGKey(0)
w0 = random.uniform(key, shape=(m,))*3  # random start in [0, 3)

print('Initial random omega = {}'.format(w0))

#learning_rate
lr = 0.1

# Build the differentiated loss once; it does not change between iterations.
loss_and_grad = value_and_grad(loss)

for itr in range(100):
    # The gradient is named grad_w so that the `grad` function imported
    # from jax is not overwritten by an array.
    val_loss, grad_w = loss_and_grad(w0)
    w0 = w0 - lr * grad_w  # plain gradient-descent step
    print(itr,val_loss,w0)

    # Potential and spectrum of the model at the updated parameter
    vt = fnn_vect(w0,x)
    Vt = jnp.diag(vt[:,0])
    Ht = T + Vt
    et,evt = jnp.linalg.eigh(Ht)

    plt.clf()
    # Left panel: target vs. model eigenvalues
    plt.subplot(1, 2, 1)
    plt.plot(e0,label = 'exact')
    plt.plot(et, label = 'model')
    # Reference levels n + 1/2
    plt.plot(jnp.arange(nx)+0.5, color='red', ls = '--')
    # Marks the cutoff of 50 states used inside `loss`
    plt.axvline(50,color='k', ls = '--')
    plt.legend()

    # Right panel: model vs. reference potential
    plt.subplot(1,2,2)
    # Show the parameter currently being optimized (w0), not the fixed target w
    plt.text(-5,10, r'$\kappa$ = {0:.3f}'.format(w0[0]))
    # vt is built from the current w0 (the model); v from the target w
    plt.plot(x,vt,label='model')
    plt.plot(x,v,label='ho')
    plt.legend()#loc=4

    plt.draw()
    plt.pause(0.1)

