In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

# Natural units: set both hbar and the particle mass to 1.
hbar, m = 1, 1
# Default number of grid points handed to investigate_potential by the driver calls.
N = 400

try:
    from numpy import heaviside
except ImportError:
    # numpy.heaviside only exists in newer NumPy releases; emulate it here.
    def heaviside(x, a):
        """Step function matching numpy.heaviside.

        Returns 0 where x < 0, ``a`` where x == 0 and 1 where x > 0.
        The value at exactly 0 matters: _interval passes a = 1 so that the
        interval endpoints are included.
        """
        x = np.asarray(x)
        return np.where(x > 0, 1.0, np.where(x == 0, a, 0.0))

# Some SciPy installations lack scipy.linalg.eigh_tridiagonal; in that case
# fall back to NumPy's dense symmetric eigensolver.
# TODO: leave only one in the end
try:
    from scipy.linalg import eigh_tridiagonal
except ImportError:
    get_eigh = np.linalg.eigh
else:
    def get_eigh(H):
        """Return (eigenvalues, eigenvectors) of the symmetric tridiagonal matrix H.

        Only the main diagonal and the first super-diagonal of H are read, so H
        must be symmetric tridiagonal. Eigenvectors are the columns of the
        second return value.
        """
        main_diagonal = np.diag(H)
        upper_diagonal = np.diag(H, k=1)
        return eigh_tridiagonal(main_diagonal, upper_diagonal)

def _interval(x, a, b):
    """Indicator of the closed interval [a, b].

    Evaluates to 1 where a <= x <= b and to 0 elsewhere. The second argument 1
    passed to heaviside is its value at exactly zero, which (with
    numpy.heaviside) makes both endpoints count as inside.
    """
    at_or_above_a = heaviside(x - a, 1)
    at_or_below_b = heaviside(b - x, 1)
    return at_or_above_a * at_or_below_b

def V_box(x):
    """Box potential

    Zero everywhere on the sampled grid.
    """
    # Multiplying x by zero (instead of returning a bare 0) keeps the shape and
    # dtype of x, so an array input yields an array of zeros.
    return x * 0

def V_harmosc(x, k=10000, x0=0.5):
    """Harmonic oscillator potential

    V(x) = k / 2 * (x - x0)**2

    k  : spring constant (default 10000)
    x0 : centre of the well (default 0.5, the middle of the unit interval)
    """
    return 0.5 * k * (x - x0) ** 2

def V_atom(x, pos=0.25, w=0.5, depth=1000):
    """Single well potential
    pos   : beginning pos of well
    w     : width of well
    depth : depth of the well (default 1000)

    Returns -depth where pos <= x <= pos + w (as decided by _interval) and 0
    elsewhere.
    """
    return -depth * _interval(x, pos, pos + w)

def V_molecule(x, w = 0.25, b = 0.25):
    """Double well potential
    Two identical wells of width w, placed symmetrically about x = 0.5 and
    separated by a gap of b between their inner edges.
    w   : width of each well
    b   : distance between wells"""
    left_start = 0.5 - w - b/2
    right_start = 0.5 + b/2
    left_well = V_atom(x, pos = left_start, w = w)
    right_well = V_atom(x, pos = right_start, w = w)
    return left_well + right_well


def V_crystal(x, N_w = 5, b=0.1):
    """Poly well potential
    N_w  : number of wells
    b    : gap between neighbouring wells

    The well width w is chosen so that N_w periods of (well w + gap b) plus a
    10w margin at each end fill the unit interval: N_w*(w + b) + 20*w = 1.
    Because every period ends with a gap b, the flat region after the last
    well is 10w + b wide, slightly wider than the 10w before the first well.

    Raises ValueError when the parameters leave no positive well width
    (e.g. N_w * b >= 1).
    """
    w = (1.0 - N_w * b)/(N_w + 20.0) # Add 20 because we want 10w of margin on either side
    if not w > 0:
        raise ValueError(
            "V_crystal needs a positive well width; got w=%r from N_w=%r, b=%r"
            % (w, N_w, b))
    margin = 10*w
    # Position of x within its (well + gap) period, measured from the first well.
    phase = (x - margin) % (w + b)
    # Only keep wells between the two margins.
    return V_atom(phase, pos=0, w=w) * _interval(x, margin, 1 - margin)
    
def _build_hamiltonian(v, dx):
    """Return the dense finite-difference Hamiltonian for potential samples v.

    Uses the three-point stencil for -hbar^2/(2m) d^2/dx^2 on a grid that
    excludes the end points (psi = 0 there). The diagonal is
    hbar^2/(m dx^2) + v and both first off-diagonals are -hbar^2/(2 m dx^2).
    """
    n = len(v)
    diagonal = hbar**2 / (m * dx**2) + v
    off_diagonal = np.full(n - 1, -hbar**2 / (2 * m * dx**2))
    return np.diag(diagonal) + np.diag(off_diagonal, k=1) + np.diag(off_diagonal, k=-1)


def investigate_potential(V, N, states = range(5), **kwargs):
    """Solve the discretised 1D Schroedinger equation for V and plot some states.

    V      : potential function called as V(x, **kwargs); the first line of its
             docstring (or its name, if it has no docstring) is the plot title
    N      : number of interior grid points on (0, 1); x = 0 and x = 1 are excluded
    states : indices of the eigenstates to draw (ascending-energy order)
    kwargs : forwarded unchanged to V

    Prints the largest deviation of eigvecs^T @ eigvecs from the identity as an
    orthonormality check, then shows a matplotlib figure.
    """
    dx = 1 / (N + 1)
    x = np.linspace(dx, N * dx, N)
    # Broadcast so a potential that returns a scalar still gives one value per point.
    v = np.broadcast_to(V(x, **kwargs), x.shape)

    H = _build_hamiltonian(v, dx)
    energies, eigvecs = get_eigh(H)
    waves = eigvecs.T  # row n is the n-th eigenvector

    print("Greatest deviation from orthonormality:")
    print(np.max(np.abs(waves @ eigvecs - np.eye(N))))

    doc = V.__doc__ or getattr(V, "__name__", "V")
    title = doc.split('\n')[0]

    plt.figure(figsize=(16, 8))
    plt.title(title + ", discrete, N = %d points" % N)
    plt.xlabel("$x / L$")
    plt.ylabel("$V(x), E_n$")
    plt.plot(x, v, "grey", linestyle="dashed", linewidth=2, label="$V(x)$")
    for i, n in enumerate(states):
        color = "C%d" % (i % 10)
        energy = energies[n]
        # Dotted horizontal line at the energy level.
        plt.plot((0, 1), (energy, energy), color + ":", linewidth=1)
        # Wave function drawn on top of its energy level, scaled by 1000 to be visible.
        plt.plot(x, energy + waves[n] * 1000, color + "-", linewidth=1,
                 label=r"$E_{%d}, \psi_{%d}$" % (n, n))
    plt.legend(loc="upper left")
    plt.show()
    
# Run each example potential on the shared grid size; the dict holds extra
# keyword arguments forwarded to the potential.
for _potential, _extra_kwargs in (
    (V_box, {}),
    (V_harmosc, {}),
    (V_atom, {}),
    (V_molecule, {"b": 0.05}),
    (V_crystal, {"N_w": 10, "b": 0.05}),
):
    investigate_potential(_potential, N, **_extra_kwargs)
del _potential, _extra_kwargs