In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

# Work in units where the reduced Planck constant and the particle mass are both 1.
hbar, m = 1, 1

# Some SciPy installations do not provide scipy.linalg.eigh_tridiagonal, so fall
# back to NumPy's dense symmetric eigensolver when it cannot be imported.
# TODO: leave only one in the end
try:
    from scipy.linalg import eigh_tridiagonal

    def get_eigh(H):
        """Eigen-decompose a symmetric tridiagonal matrix H.

        Only the main diagonal and the first superdiagonal of H are read.
        Returns (eigenvalues, eigenvectors) with the eigenvectors as columns,
        matching the return convention of numpy.linalg.eigh.
        """
        return eigh_tridiagonal(np.diag(H), np.diag(H, k=1))
except ImportError:
    get_eigh = np.linalg.eigh

def V_box(x):
    """Box potential"""
    # NOTE: the docstring above is used verbatim as the plot title, so it is
    # runtime text. The potential is flat (zero) at every grid point; the
    # confinement comes from the grid excluding the endpoints x = 0 and x = 1.
    return 0

def V_harmosc(x):
    """Harmonic oscillator potential"""
    # Parabola k/2 * (x - center)^2 with its minimum at the middle of [0, 1].
    k = 10000
    center = 1/2
    return k * 1/2 * (x - center)**2

def V_atom(x):
    """Single well potential"""
    # Square well: -1000 strictly between 0.25 and 0.75, zero everywhere else.
    inside_well = 0.25 < x < 0.75
    return -1000 if inside_well else 0

def V_molecule(x):
    """Double well potential"""
    # Two copies of V_atom squeezed side by side: each half of [0, 1] is
    # shifted to start at 0 and stretched by 2 before evaluating V_atom.
    left_edge = 0.0 if x < 0.5 else 0.5
    return V_atom(2 * (x - left_edge))
    
def investigate_potential(V, N, states):
    """Solve the discretised 1D Schrodinger equation for V on [0, 1] and plot states.

    The interval is sampled at N interior points with spacing dx = 1/(N+1).
    The endpoints x = 0 and x = 1 are excluded from the grid, so the
    wavefunction is implicitly zero there. The kinetic term uses the
    three-point central-difference second derivative, which makes the
    Hamiltonian a symmetric tridiagonal matrix.

    Parameters
    ----------
    V : callable
        Potential, called once per grid point as ``V(x)`` with a scalar x.
        Its docstring (or its name, if it has none) is used in the plot title.
    N : int
        Number of interior grid points.
    states : iterable of int
        Indices into the eigenvalue-sorted spectrum (0 = lowest energy) of
        the states to plot.

    Prints the largest deviation of the eigenvectors from orthonormality,
    shows a matplotlib figure, and returns None.
    """
    dx = 1 / (N + 1)
    x = np.linspace(dx, N * dx, N)
    v = np.array([V(xi) for xi in x], dtype=float)

    # -hbar^2/(2m) * (psi[i-1] - 2 psi[i] + psi[i+1]) / dx^2 + V[i] psi[i]
    diag = hbar**2 / (m * dx**2) + v
    offdiag = np.full(N - 1, -hbar**2 / (2 * m * dx**2))
    H = np.diag(diag) + np.diag(offdiag, k=1) + np.diag(offdiag, k=-1)

    eigvals, eigvecs = get_eigh(H)
    # Eigenvectors are the columns of eigvecs; transpose so waves[n] is state n.
    waves = eigvecs.T

    # Sanity check: a symmetric eigensolver should return orthonormal vectors.
    overlap = waves @ eigvecs
    print("Greatest deviation from orthonormality:")
    print(np.max(np.abs(overlap - np.eye(N))))

    # Fall back to the callable's name so potentials without a docstring
    # (e.g. lambdas) do not crash the title concatenation.
    title = V.__doc__ or getattr(V, "__name__", repr(V))

    plt.figure(figsize=(16, 8))
    plt.title(title + ", discrete, N = %d points" % N)
    plt.xlabel("$x / L$")
    plt.ylabel("$V(x), E_n$")
    plt.plot(x, v, "k-", linewidth=3, label="$V(x)$")
    for i, n in enumerate(states):
        energy = eigvals[n]
        wave = waves[n]
        color = "C%d" % (i % 10)
        # Dotted horizontal line marks the energy level across the interval.
        plt.plot((0, 1), (energy, energy), color + ":", linewidth=1)
        # Wavefunction drawn on top of its energy level; 1000 is a fixed
        # visual amplitude scale. Raw string keeps the LaTeX backslash intact.
        plt.plot(x, energy + wave * 1000, color + "-", linewidth=1,
                 label=r"$E_{%d}, \psi_{%d}$" % (n, n))
    plt.legend(loc="upper left")
    plt.show()
    
# Plot the five lowest states of each example potential on a 200-point grid.
for potential in (V_box, V_harmosc, V_atom, V_molecule):
    investigate_potential(potential, 200, range(0, 5))