In [3]:
from matplotlib import pyplot as plt
import numpy as np

In [4]:
def wavefunc_visualiser(psi, pt_index, size=61, orbital_no=8, labels=None):
    '''Plot the per-orbital probability density of one wavefunction and return the orbital weights.

    psi - boundstate wavefunction object. In general an (L, N) matrix (L being the size of the
       scattering-region Hamiltonian and N the number of solutions returned by the boundstate
       algorithm). Anything convertible with ``np.asarray`` is accepted.

    pt_index - column index into ``psi`` selecting the eigenvector to visualise (e.g. the one
       belonging to the N'th boundstate energy). If you only have one wavefunction, treat
       ``psi[:, pt_index]`` as your input wavefunction.

    size - number of lattice sites in the scattering-region Hamiltonian.

    orbital_no - number of orbitals per site. Amplitudes are read site by site: row
       ``site * orbital_no + orbital`` of the selected column belongs to ``orbital``. The
       selected column must therefore have exactly ``size * orbital_no`` entries.

    labels - optional sequence of ``orbital_no`` y-axis labels, one per orbital. When omitted,
       the built-in 8-component labels are used if ``orbital_no == 8``, otherwise generic
       ``|psi_i|^2`` labels.

    Draws one subplot per orbital (two columns) showing |psi|^2 against the site index, prints
    the summed weight of each orbital followed by the sum over all orbitals, and returns the
    per-orbital weights as a 1D numpy array of length ``orbital_no``.

    Raises ValueError if ``orbital_no`` is not positive, if the selected column does not have
    ``size * orbital_no`` entries, or if ``labels`` has the wrong length.
    '''
    # asarray: indexing/powers on np.matrix keep 2D matrix semantics (``**`` is a matrix power).
    wavefunc = np.asarray(psi)[:, pt_index]

    if orbital_no < 1:
        raise ValueError('orbital_no must be a positive integer, got %r' % (orbital_no,))
    if wavefunc.shape[0] != size * orbital_no:
        raise ValueError(
            'selected wavefunction has %d entries but size * orbital_no = %d * %d = %d'
            % (wavefunc.shape[0], size, orbital_no, size * orbital_no))

    if labels is None:
        if orbital_no == 8:
            # Component order named by these labels: electron (e) orbitals 0 and 1 with spin
            # up/down, then the same four components for the hole (h) part.
            labels = [r'$|\psi_{0, \uparrow}^{e}|^2$', r'$|\psi_{0, \downarrow}^{e}|^2$',
                      r'$|\psi_{1, \uparrow}^{e}|^2$', r'$|\psi_{1, \downarrow}^{e}|^2$',
                      r'$|\psi_{0, \uparrow}^{h}|^2$', r'$|\psi_{0, \downarrow}^{h}|^2$',
                      r'$|\psi_{1, \uparrow}^{h}|^2$', r'$|\psi_{1, \downarrow}^{h}|^2$']
        else:
            # The named labels above only describe an 8-component basis; fall back to indices.
            labels = [r'$|\psi_{%d}|^2$' % i for i in range(orbital_no)]
    else:
        labels = list(labels)
        if len(labels) != orbital_no:
            raise ValueError('expected %d labels, got %d' % (orbital_no, len(labels)))

    # |psi|^2 arranged as (site, orbital): column i is the same data as wavefunc[i::orbital_no].
    density = (np.abs(wavefunc) ** 2).reshape(size, orbital_no)
    sites = np.arange(size)

    # Round the row count up so an odd number of orbitals still gets one panel each.
    n_rows = (orbital_no + 1) // 2
    # squeeze=False keeps ``axes`` 2D even for a single row, so ravel() is always valid.
    _, axes = plt.subplots(n_rows, 2, figsize=(15, 15), squeeze=False)
    flat_axes = axes.ravel()
    for ax, column, label in zip(flat_axes, density.T, labels):
        ax.plot(sites, column)
        ax.set_xlabel('Site indices', fontsize=16)
        ax.set_ylabel(label, fontsize=18)
    for ax in flat_axes[orbital_no:]:
        ax.set_visible(False)  # spare slot when orbital_no is odd

    # Per-orbital weights, then their total.
    orbital_probs = density.sum(axis=0)
    for orbital_prob in orbital_probs:
        print(orbital_prob)
    print(np.sum(orbital_probs))

    return orbital_probs