# To do
* Limit inputs (e.g. z > 0, maybe 0 <= theta <= pi/3)
* Fix `smart_cubic_k` so it copes with awkward user-supplied inputs
* Decide how to handle scattered wave vectors parallel to the plane (for example, with k_in = (0, 0, -2pi))
* Use a widescreen figure window with subplots
* Show the detection plane and the outgoing wave vectors in the main plot window
* Make sure the FCC and BCC lattices work
* Input sanitization
* In the final app, offer only the allowed options (a dropdown menu, for example)

In [4]:
from cmp import *
import pdir
%matplotlib notebook
import sys

# Uncomment while debugging to turn invalid floating-point operations into errors.
#np.seterr(invalid='raise')
# Print arrays in full instead of summarising them. NumPy rejects a NaN
# threshold (ValueError); sys.maxsize is the documented way to disable truncation.
np.set_printoptions(threshold=sys.maxsize)

In [5]:
# Inputs
# Short alias for approximate floating-point comparison.
eq = np.isclose
# NOTE(review): a and b are not used in this cell; presumably lattice constants - confirm.
a = 1
b = 2
# Lattice vectors (3 vectors of length 3); here the Cartesian unit vectors.
a1 = np.array([1, 0, 0])
a2 = np.array([0, 1, 0])
a3 = np.array([0, 0, 1])
# Angle in radians (80 degrees).
theta = np.deg2rad(80)

# Array of basis positions. These four are the atom positions of the
# conventional cubic fcc cell (matching LatticeType below).
basis = np.array([[0, 0, 0], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5]])
# Colour of each basis atom, in the same order as `basis`.
blargh = ('r', 'r', 'b', 'b')
# Size multiplier for each basis atom, in the same order as `basis`. Default is 1.
sizes = (2, 2, 1, 1)
verbose = True


# Gridline type.
# NOTE(review): the options documented here used to be "Soft" (lines along the
# cartesian axes, taking nonequal lattice spacing into account) and
# "LatticeVectors" (lines along the lattice vectors, only on lattice points),
# but the value actually used is "lattice". Confirm the accepted names against
# the lattice module before relying on this setting.
GridType = "lattice"

# Limit type.
# NOTE(review): the options documented here used to be "individual" (limits
# max(nx*a1, ny*a2, nz*a3), i.e. nx unit cells in the a1 direction, etc.) and
# "sum" (r_min = n_min*[a1 a2 a3], likewise for n_max), but the value actually
# used is "dynamic". Confirm the accepted names against the lattice module.
LimType = "dynamic"
# Presumably the upper/lower unit-cell counts along a1, a2, a3 (n_max/n_min
# above) - TODO confirm.
Maxs = [2, 2, 2]
Mins = [0, 0, 0]

LatticeType = "conventional fcc"

# Example calls (currently disabled):
#Lattice(lattice_name = LatticeType, colors = blargh, sizes = sizes, max_ = Maxs, verbose=True)
#Reciprocal(lattice_name=LatticeType, indices=(1,1,0))

In [7]:
def setup_scattering(lattice_name='simple cubic', k_in=None,
                     scattering_length=None, highlight=None, verbose=False):
    """Plot a lattice, the incident beam and the resulting scattering pattern.

    Builds the lattice named ``lattice_name`` with the ``lattices`` helpers,
    computes scattered wave vectors with ``scattering.calc_scattering`` and
    draws one figure. The left half holds a 3D view of the atoms, grid lines
    and incoming beam. On the right is a 2D "detector screen" with the
    projected scattered points, each labelled with its indices.

    Parameters
    ----------
    lattice_name : str
        Lattice name (case-insensitive). It must be a key of
        ``lattices.latticelines`` and understood by ``lattices.chooser``.
    k_in : array_like of length 3, optional
        Incoming wave vector. Defaults to ``[0, 0, -3*pi]``.
    scattering_length : array_like, optional
        Passed to ``scattering.calc_scattering``. Defaults to ``[1, 1, 1, 1]``
        (presumably one value per basis atom - TODO confirm).
    highlight : sequence of 3 ints, optional
        Indices of a scattering peak to highlight. The peak is drawn in red on
        the detector, and the planes returned by ``lattices.reciprocal`` for
        those indices are drawn in the 3D view. Invalid or absent indices only
        print a message.
    verbose : bool
        Forwarded to ``lattices.chooser`` and ``lattices.grid_lines``.
    """
    # Defaults are built per call: array defaults in the signature would be
    # shared (and mutable) across all calls.
    if k_in is None:
        k_in = np.array([0, 0, -3 * np.pi])
    k_in = np.asarray(k_in)
    if scattering_length is None:
        scattering_length = np.array([1, 1, 1, 1])
    scattering_length = np.asarray(scattering_length)

    # Detector axes rectangle [left, bottom, width, height] in figure fractions.
    detector_screen_position = [0.6, 0.2, 0.3, 0.6]

    lattice_name = lattice_name.lower()
    # The drawn beam is one wavelength (2*pi/|k_in|) long and ends at z = beam_end_z.
    beam_length = 2 * np.pi / lattices.mag(k_in)
    beam_end_z = 1
    min_, max_ = (-2, -2, -1), (2, 2, 1)
    grid_type = lattices.latticelines[lattice_name]
    lim_type = "proper"
    lattice_colors = ["xkcd:cement",
                      "xkcd:cornflower blue",
                      "xkcd:cornflower blue",
                      "xkcd:cornflower blue"]
    lattice_sizes = [1, 1, 1, 1]
    # Grid line colour, width and alpha.
    g_col = 'k'
    g_w = 0.5
    g_a = 0.6

    # Build the lattice and restrict it to the plotting limits.
    (a1, a2, a3), basis, _ = lattices.chooser(lattice_name, verbose=verbose)
    r_min, r_max, n_min, n_max = lattices.find_limits(lim_type, a1, a2, a3,
                                                      min_, max_)
    (atomic_positions, lattice_coefficients, atomic_colors, atomic_sizes,
     lattice_position) = lattices.generator(a1, a2, a3, basis, lattice_colors,
                                            lattice_sizes, lim_type, n_min,
                                            n_max, r_min, r_max)
    objects = [atomic_positions, lattice_coefficients, atomic_colors,
               atomic_sizes, lattice_position]
    objects = lattices.limiter(atomic_positions, objects, r_min, r_max)
    (atomic_positions, lattice_coefficients, atomic_colors, atomic_sizes,
     lattice_position) = objects

    pruned_lines = lattices.grid_lines(a1, a2, a3, atomic_positions,
                                       lattice_position, grid_type,
                                       verbose=verbose)

    # Scattering
    intensities, k_out, indices = scattering.calc_scattering(
        a1, a2, a3, basis, scattering_length, k_in)
    points = scattering.projection(k_out)

    # Normalise intensities to [0, 1] so they can serve as alpha values.
    # Done out of place so integer intensities work. The guard avoids 0/0
    # (NaN alphas) for an all-zero result and np.amax on an empty array.
    intensities = np.asarray(intensities, dtype=float)
    if intensities.size and np.amax(intensities) > 0:
        intensities = intensities / np.amax(intensities)
    # RGBA colours: black points whose alpha is the normalised intensity.
    colors = np.zeros((intensities.size, 4))
    colors[:, 3] = intensities

    planes = None
    if highlight is not None:
        high_index = np.array(highlight)
        if high_index.shape != (3,):
            print("We need 3 and only 3 indices! Highlighting nothing")
        else:
            matches = np.where((indices == high_index).all(axis=1))[0]
            if matches.shape != (1,):
                print("There is no scattering along {}".format(highlight))
            else:
                match = matches[0]
                # Extra margin around the plot limits for the planes (none for now).
                extra = 0
                _, planes = lattices.reciprocal(a1, a2, a3, high_index,
                                                r_min - extra, r_max + extra,
                                                points=20)
                planes = lattices.plane_limiter(planes, r_min - extra,
                                                r_max + extra)
                # Scalar index: assigning a list that contains a 1-element
                # array would build a ragged array, which modern NumPy rejects.
                colors[match] = [1, 0, 0, intensities[match]]

    # Plotting
    fig = plt.figure(figsize=(12.8, 4.8))
    # Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot
    # creates the 3D axes on old and new versions alike.
    ax = fig.add_subplot(111, projection="3d")
    ax.set_position([0, 0, 0.5, 1])

    # Plot atoms
    ax.scatter(atomic_positions[:, 0], atomic_positions[:, 1],
               atomic_positions[:, 2], c=atomic_colors, s=atomic_sizes)

    for line in pruned_lines:
        ax.plot(line[0], line[1], line[2], color=g_col, linewidth=g_w,
                alpha=g_a)

    # The beam is drawn antiparallel to k_in and ends at z = beam_end_z.
    k_disp = beam_length * k_in / lattices.mag(k_in)
    beam = np.array([[-k_disp[0], 0],
                     [-k_disp[1], 0],
                     [-k_disp[2] + beam_end_z, beam_end_z]])
    ax.plot(beam[0], beam[1], beam[2], linewidth=2, color='b')

    # Detector screen with labelled scattering points.
    ax2 = fig.add_axes(detector_screen_position)
    ax2.tick_params(axis="both", labelbottom=False, labelleft=False)
    if len(points) > 0:
        # Labels are offset by 5% of the x/y data range.
        ranges = (np.amax(points, axis=0) - np.amin(points, axis=0))[:-1]
        ax2.scatter(points[:, 0], points[:, 1], c=colors)
        for point, label, color in zip(points, indices, colors):
            x, y = point[0:2] - 0.05 * ranges
            ax2.text(x, y, label, color=color[:-1], va='top', ha='right')

    if planes is not None:
        for p in planes:
            ax.plot_surface(p[0], p[1], p[2], color="r", shade=False,
                            alpha=0.2)

    ax.set_aspect('equal')
    ax.set_proj_type('ortho')
    ax.set_xlim([r_min[0], r_max[0]])
    ax.set_ylim([r_min[1], r_max[1]])
    ax.set_zlim([r_min[2], r_max[2]])
    # Transparent panes, no grid or axes: only the crystal is shown.
    ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.grid(False)
    ax.axis('off')
    ax.axis('equal')
    

# Incoming wave vector from the cubic helper.
# NOTE(review): the meaning of indices/phi/theta is defined by
# scattering.smart_cubic_k - confirm there.
k_in = scattering.smart_cubic_k(indices=(0, 0, 1), phi=0, theta=np.pi / 2)
# Example call (currently disabled):
#setup_scattering(k_in=np.array([0,0,-2*np.pi]), highlight=(0,0,2))