# Plotting states in 3D with Python: example C60
Python has a very rich ecosystem for all kinds of atomistic simulation. `ASE` (the Atomic Simulation Environment) is fairly common and probably has the simplest interface, so we use it here.

In [None]:
from pathlib import Path  # Builtin library. Utility for file paths.
import numpy as np
import ase
from ase.visualize import view as aseviewer
from ase.io import write, read
import ase.build
from ase.data import pubchem

In [None]:
# Short aliases for the NumPy linear-algebra routines used throughout the notebook
Inv, Tr, MM = np.linalg.inv, np.trace, np.matmul

In [None]:
import nglview as nv

In [None]:
from nglview.contrib.movie import MovieMaker

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

### Get molecule from ASE 
(Check out the other systems ASE can generate, such as periodic graphene or nanoribbons.)

In [None]:
# Build C60 from ASE's built-in molecule collection and show it with nglview
# (the bare `view` on the last line makes Jupyter display the widget).
C60 = ase.build.molecule("C60")
view = nv.show_ase(C60)
view

In [None]:
xyz = C60.get_positions()  # atomic positions of C60; presumably an (N, 3) array, one row per atom
xyz

### Try the ASE external viewer. 
In the viewer you can select atoms, delete them, and make all sorts of other modifications. You can then save the structure as an xyz file (e.g. "test.xyz") and read it back into the Jupyter notebook with `read`, as in the commented-out lines of the next cell.

In [None]:
# Open ASE's external GUI viewer. The viewer is imported at the top of the
# notebook as `aseviewer`; calling `aseview` raised a NameError.
aseviewer(C60)
# To load a structure that was edited and saved from the viewer:
#atoms= read("test.xyz")
#aseviewer(atoms)

In [None]:
# Generate nearest neighbor hopping Hamiltonian from coordinates
def hamiltonian(xyz, a=1.4, hopping=-1, tol=0.1):
    """Return the nearest-neighbour tight-binding Hamiltonian for positions ``xyz``.

    Sites i and j are coupled with matrix element ``hopping`` when their
    distance is below ``a + tol``. Distances below ``tol`` are excluded, which
    keeps the diagonal (a site with itself) at zero. All other entries are 0.

    Parameters
    ----------
    xyz : array_like, shape (N, d)
        Site positions, one row per site.
    a : float, optional
        Nearest-neighbour distance (default 1.4, roughly the C-C bond length).
    hopping : number, optional
        Value placed on bonded pairs (default -1, an int, so the result is an
        integer array as before).
    tol : float, optional
        Tolerance added to ``a``, also used as the self-distance cut-off.

    Returns
    -------
    numpy.ndarray, shape (N, N)
        Symmetric Hamiltonian matrix.
    """
    xyz = np.asarray(xyz, dtype=float)
    # All pairwise distances via broadcasting: dist[i, j] = |xyz[j] - xyz[i]|.
    dist = np.linalg.norm(xyz[None, :, :] - xyz[:, None, :], axis=2)
    bonded = (dist < (a + tol)) & (dist > tol)
    return np.where(bonded, hopping, 0)
# Hamiltonian of the C60 geometry, used by the cells below
H = hamiltonian(xyz=xyz)


In [None]:
plt.spy(H)  # sparsity pattern of H: a marker wherever H[i, j] != 0, i.e. atoms i and j are bonded

In [None]:
# Check: every carbon atom should have exactly 3 neighbours (bonds).
# Count the non-zero couplings in each column. Summing H would show -3 here,
# because every hopping element is -1.
np.count_nonzero(H, axis=0)

In [None]:
# Diagonalise the symmetric H: es holds the eigenvalues in ascending order and
# column vs[:, n] is the normalised eigenvector belonging to es[n].
es, vs = np.linalg.eigh(H)

In [None]:
def showspectrum(xyz):
    """Draw a bar chart of the tight-binding spectrum for positions ``xyz``.

    The eigenvalues are rounded to one decimal, so (near-)degenerate levels
    fall into one bar. Each bar's height is the number of eigenvalues in that
    group. Returns nothing; the figure is left open for display.
    """
    energies = np.linalg.eigh(hamiltonian(xyz))[0]
    levels, degeneracy = np.unique(np.round(energies, 1), return_counts=True)

    _, axes = plt.subplots()
    axes.bar(levels, height=degeneracy, width=0.1)
    axes.set_ylabel("degeneracy")
    axes.set_xlabel("energy")
    axes.grid(axis="y")


In [None]:
showspectrum(xyz)  # degeneracy of each (rounded) energy level of C60

In [None]:
# Plot states with nglview:
def plot_states3D(atoms, psi):
    """Return an nglview widget showing the real state ``psi`` on ``atoms``.

    The molecule is drawn as lines. Every atom gets a sphere of radius
    ``10 * max|psi| * |psi[i]|``, red where ``psi[i] > 0`` and blue otherwise.

    NOTE(review): this colour convention is the reverse of ``plot3Dwf``
    (blue = positive there); confirm which one is intended.
    NOTE(review): the radius grows with ``max|psi|`` instead of being divided
    by it, so spheres get smaller for delocalised states; confirm intended.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to draw.
    psi : array_like of real numbers
        One amplitude per atom.

    Raises
    ------
    AssertionError
        If ``len(psi)`` differs from the number of atoms.
    """
    na = atoms.get_global_number_of_atoms()
    # Validate before building the widget.
    assert na == len(psi), "Atom/Basis not right!"
    view = nv.show_ase(atoms)
    view.clear_representations()
    # A single 'line' representation; adding it twice only duplicated the drawing.
    view.add_representation('line', selection='all')
    amplitudes = np.abs(psi)  # also works when psi is a plain list
    scale = amplitudes.max() * 10
    # Add a sphere for each atom with a size given by its amplitude
    for at, amp, value in zip(atoms, amplitudes, psi):
        color = [1, 0, 0] if value > 0 else [0, 0, 1]
        view.shape.add_sphere(at.position, color, scale * amp)
    return view

In [None]:
plt.plot(es)  # eigenvalues versus their index (eigh returns them in ascending order)

In [None]:
# Choose an eigenstate (index into the ascending eigenvalues es) and show it with nglview.
istate = 0
view = plot_states3D(C60, vs[:, istate])
view

In [None]:
view.download_image()  # save a snapshot of the nglview widget shown above

### Make 3D plot
Here we use matplotlib to produce and save still pictures and small animated GIFs.

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.distance import pdist, squareform
from PIL import Image
from tqdm import tqdm  # For the progress bar

In [None]:
def plot3Dwf(xyz, wf, title, threshold_distance=1.5, size_scale=2000):
    """Draw a real wavefunction on a 3D ball-and-stick picture of the molecule.

    Each atom is a semi-transparent marker of area ``size_scale * |wf[i]|``,
    blue where ``wf[i] > 0`` and red otherwise. Atoms closer than
    ``threshold_distance`` are joined by black bond lines.

    Parameters
    ----------
    xyz : array_like, shape (N, 3)
        Atomic positions.
    wf : array_like, shape (N,)
        Real amplitudes, one per atom (pass ``psi.real`` for a complex state).
    title : str
        Axes title.
    threshold_distance : float, optional
        Bond cut-off. The default 1.5 sits just above the ~1.4 C-C bond length.
    size_scale : float, optional
        Marker-area multiplier applied to ``|wf|``.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure. The caller is responsible for saving and closing it.
    """
    xyz = np.asarray(xyz, dtype=float)
    wf = np.asarray(wf)
    distances = squareform(pdist(xyz))
    # Keep only the upper triangle so each bond (i, j) is drawn once. The full
    # symmetric matrix lists both (i, j) and (j, i), which draws every line
    # twice and makes the alpha=0.7 bonds darker than intended.
    edges = np.argwhere(np.triu((distances < threshold_distance) & (distances > 0)))

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')

    # All atoms in one scatter call. depthshade=False keeps every marker at a
    # uniform alpha of 0.6.
    colors = ['blue' if amp > 0 else 'red' for amp in wf]
    ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=np.abs(wf) * size_scale,
               c=colors, alpha=0.6, depthshade=False)
    # Bonds: one line segment per bonded pair.
    for i, j in edges:
        x_values, y_values, z_values = xyz[[i, j]].T
        ax.plot(x_values, y_values, z_values, color='black', alpha=0.7)

    ax.set_axis_off()  # hide axes, panes and grid
    ax.set_box_aspect([1, 1, 1])  # equal aspect ratio for all axes
    ax.view_init(elev=0, azim=30)
    ax.set_title(title)
    return fig
    


#### Generate 3D plots of all eigenstates and save them as PNG files

In [None]:
# Render every eigenstate (one per eigenvalue, not a hard-coded 60) and save
# each figure as state_<n>.png.
frame_filenames = []
for istate in tqdm(range(len(es)), "Generating plots"):
    filename = f"state_{istate}.png"
    title = f"State {istate}, E={np.round(es[istate], 2)}"
    fig = plot3Dwf(xyz, vs[:, istate], title)
    # Save the figure we just made, explicitly; savefig renders it, so no plt.draw() is needed.
    fig.savefig(filename, dpi=100, bbox_inches='tight')
    frame_filenames.append(filename)
    plt.close(fig)  # free the figure so dozens of plots don't accumulate in memory

### Time-propagation
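We calculate the time propagator $\hat{U}(t)=e^{-i\hat{H}t}$ (with $\hbar=1$) for a small time step $t$ and apply it repeatedly to the initial state, so that after $n$ steps $\psi(nt) = \hat{U}(t)^n\psi_0 = \hat{U}(nt)\psi_0$.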

In [None]:
# Time step for one application of the propagator (hbar = 1).
t = 0.1
# U(t) = exp(-i H t) = V diag(exp(-i E t)) V^dagger, with V = vs and E = es.
# Note the minus sign: exp(+i E t) would build U(-t) and run time backwards.
# Multiplying vs by the phases scales column n by exp(-i E_n t), i.e. V @ diag(...).
# eigh returns orthonormal eigenvectors, so V^{-1} = V^dagger and no explicit
# matrix inversion is needed.
U = MM(vs * np.exp(-1.j * es * t), vs.conj().T)

In [None]:
# Initial state: fully localised on atom 0. Its size follows the basis
# (len(es) = number of atoms) instead of a hard-coded 60.
psi0 = np.zeros(len(es))
psi0[0] = 1.

In [None]:
# Apply the one-step propagator ntimesteps times; afterwards allpsi[n] == U^n psi0.
ntimesteps = 50
allpsi = [psi0]
psit = psi0
for it in range(ntimesteps):
    psit = U @ psit
    allpsi.append(psit)

In [None]:
# Render the real part of psi at every stored step and save it as psi_<n>.png.
frame_filenames = []
# Remove frames from earlier runs. This is a portable replacement for the shell
# magic `!rm psi_*.png`, which errors when nothing matches.
for stale in Path(".").glob("psi_*.png"):
    stale.unlink()
# `it` is the step index; the physical time of frame `it` is it * t.
for it, psi in tqdm(enumerate(allpsi), "Generating plots", total=len(allpsi)):
    filename = f"psi_{it}.png"
    fig = plot3Dwf(xyz, psi.real, "t="+str(it))
    fig.savefig(filename, dpi=100, bbox_inches='tight')
    frame_filenames.append(filename)
    plt.close(fig)  # free the figure before rendering the next frame


In [None]:
import imageio

In [None]:
# Stitch the frames listed in frame_filenames, in order, into an animated GIF.
# NOTE(review): newer imageio releases deprecate `imageio.imread` in favour of
# `imageio.v2.imread`, and the unit of `duration` (seconds vs milliseconds) has
# differed between GIF plugins; confirm against the installed imageio version.
with imageio.get_writer("psi-in-time.gif", mode='I', duration=0.15 ) as writer:
    for filename in tqdm(frame_filenames, desc="Creating GIF"):
        image = imageio.imread(filename)
        writer.append_data(image)
       

In [None]:
# NOTE: this rebinds `Image`, shadowing the PIL `Image` imported earlier in the notebook.
from IPython.display import Image

# Display the GIF inline (last expression of the cell)
Image(filename="psi-in-time.gif")