# Stuff to do
* Limit the inputs (z > 0, and possibly 0 <= theta <= pi/3).
* Fix `smart_cubic_k` for users who supply awkward inputs.
* Decide what to do with scattered wave vectors that are parallel to the plane (for example, k_in = (0, 0, -2pi)).
* Verify by hand whether those vectors really are parallel to the plane.



In [31]:
import cmp
import pdir
from lattices import *
from scattering import *
%matplotlib notebook
import sys

# Raise FloatingPointError on invalid operations (e.g. 0/0) instead of
# silently producing NaN.
np.seterr(invalid='raise')
# Print whole arrays without "..." summarisation. NumPy rejects a NaN
# threshold; sys.maxsize is the documented "never truncate" value.
np.set_printoptions(threshold=sys.maxsize)

In [2]:
# Inputs
# Shorthand for approximate floating-point comparison
eq = np.isclose

# Lattice constants and the three lattice vectors (each of length 3)
a, b = 1, 2
a1, a2, a3 = (np.array(vec) for vec in ([1, 0, 0], [0, 1, 0], [0, 0, 1]))
# 80 degrees, expressed in radians
theta = np.pi * 80 / 180

# Basis vectors, one row per atom in the basis
basis = np.array([
    [0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0],
    [0.5, 0.0, 0.5],
    [0.0, 0.5, 0.5],
])
# Per-basis-atom colours and size multipliers (a multiplier of 1 is the default)
blargh = ('r', 'r', 'b', 'b')
sizes = (2, 2, 1, 1)
verbose = True

# Grid-line type. Options described by the original notes:
#   Soft           - lines along the Cartesian axes; accounts for unequal
#                    lattice spacing
#   LatticeVectors - lines along the lattice vectors (only on lattice points)
# NOTE(review): "lattice" is not one of the names listed above - confirm which
# values the plotting code actually accepts.
GridType = "lattice"

# Limit type. Options described by the original notes:
#   individual - limits are max(nx*a1, ny*a2, nz*a3), i.e. nx unit cells
#                along a1, etc.
#   sum        - r_min = n_min*[a1 a2 a3], and likewise for n_max
# NOTE(review): "dynamic" is not one of the names listed above - confirm.
LimType = "dynamic"
Maxs = [2, 2, 2]
Mins = [0, 0, 0]

LatticeType = "conventional fcc"

#Lattice(lattice_name = LatticeType, colors = blargh, sizes = sizes, max_ = Maxs, verbose=True)
#Reciprocal(lattice_name=LatticeType, indices=(1,1,0))

In [44]:
def setup_scattering(lattice_name='simple cubic', k_in=None,
                     scattering_length=None, points=None, intensities=None,
                     indices=None, verbose=False):
    """
    Plot a lattice, the incoming beam, and the resulting scattered points.

    Parameters
    ----------
    lattice_name : str
        Lattice name. It is lower-cased, passed to ``chooser``, and used as
        a key into ``latticelines``.
    k_in : array_like of length 3, optional
        Incoming wave vector. Defaults to ``[0, 0, -3*pi]``. Must be
        non-zero, because it is normalised to draw the beam.
    scattering_length : array_like, optional
        Forwarded to ``calc_scattering``. Defaults to ``[1, 1, 1, 1]``
        (presumably one entry per basis atom; TODO confirm).
    points, intensities, indices : ignored
        Kept only for backward compatibility; they are recomputed internally.
    verbose : bool
        Forwarded to ``chooser`` and ``grid_lines``.

    Scattered waves whose z-component is negative or zero are discarded
    before plotting. Nothing is returned; the figure is drawn with matplotlib.
    """
    # Array defaults are built per call rather than shared between calls.
    if k_in is None:
        k_in = np.array([0, 0, -3 * np.pi])
    if scattering_length is None:
        scattering_length = np.array([1, 1, 1, 1])

    # Scattering / display options
    delete_negative = True   # drop scattered waves with k_z < 0
    delete_zero = True       # drop scattered waves with k_z == 0
    project_sphere = True    # hemisphere projection instead of the 2D inset
    sphere_alpha = 0.1
    sphere_radius = np.sqrt(8)
    sphere_points = 20
    size_default = 36
    point_sizes = 2 * size_default

    # Lattice display options
    lattice_name = lattice_name.lower()
    beam_length = 2
    beam_end_z = 1
    min_, max_ = (-2, -2, -1), (2, 2, 1)
    grid_type = latticelines[lattice_name]
    lim_type = "proper"
    lattice_colors = ["xkcd:cement",
                      "xkcd:cornflower blue",
                      "xkcd:cornflower blue",
                      "xkcd:cornflower blue"]
    lattice_sizes = [1, 1, 1, 1]
    g_col = 'k'
    g_w = 0.5
    g_a = 0.6

    # Real-space lattice drawn underneath the scattering data
    (a1, a2, a3), basis, _ = chooser(lattice_name, verbose=verbose)
    r_min, r_max, n_min, n_max = find_limits(lim_type, a1, a2, a3,
                                             min_, max_)
    (atomic_positions, lattice_coefficients, atomic_colors, atomic_sizes,
     lattice_position) = generator(a1, a2, a3, basis, lattice_colors,
                                   lattice_sizes, lim_type, n_min, n_max,
                                   r_min, r_max)
    objects = [atomic_positions, lattice_coefficients, atomic_colors,
               atomic_sizes, lattice_position]
    (atomic_positions, lattice_coefficients, atomic_colors, atomic_sizes,
     lattice_position) = limiter(atomic_positions, objects, r_min, r_max)
    pruned_lines = grid_lines(a1, a2, a3, atomic_positions,
                              lattice_position, grid_type, verbose=verbose)

    # Scattering: compute, discard unwanted directions, project, merge
    intensities, k_out, indices = calc_scattering(a1, a2, a3, basis,
                                                  scattering_length, k_in)
    k_out, intensities, indices = _drop_scattered_waves(
        k_out, intensities, indices, delete_negative, delete_zero)
    if project_sphere:
        points = projection_sphere(k_out, r=sphere_radius)
    else:
        points = projection(k_out)
    plot_points, plot_intensities, plot_indices = prune_scattered_points(
        points, intensities, indices)
    plot_colors = _intensity_colors(plot_intensities)

    # Plotting. Figure.gca() no longer accepts a projection keyword in
    # recent matplotlib; add_subplot is the supported way to get 3D axes.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    if not project_sphere:
        ax.set_position([0, 0, 0.8, 0.8])

    # Atoms and grid lines
    ax.scatter(atomic_positions[:, 0], atomic_positions[:, 1],
               atomic_positions[:, 2], c=atomic_colors, s=atomic_sizes)
    for line in pruned_lines:
        ax.plot(line[0], line[1], line[2], color=g_col, linewidth=g_w,
                alpha=g_a)

    # Beam: a segment of length beam_length along k_in that ends at
    # (0, 0, beam_end_z)
    k_disp = beam_length * k_in / np.linalg.norm(k_in)
    beam = np.array([[-k_disp[0], 0],
                     [-k_disp[1], 0],
                     [-k_disp[2] + beam_end_z, beam_end_z]])
    ax.plot(beam[0], beam[1], beam[2], linewidth=2, color='b')

    # NOTE(review): some matplotlib versions reject set_aspect('equal') on
    # 3D axes - confirm against the installed version.
    ax.set_aspect('equal')
    ax.set_proj_type('ortho')
    ax.set_xlim([r_min[0], r_max[0]])
    ax.set_ylim([r_min[1], r_max[1]])
    ax.set_zlim([r_min[2], r_max[2]])
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 0.0))
    ax.grid(False)
    plt.axis('off')
    plt.axis('equal')

    if project_sphere:
        # Upper hemisphere centred on (0, 0, 1), the default origin used by
        # projection_sphere
        phi = np.linspace(0, 2 * np.pi, sphere_points)
        polar = np.linspace(0, np.pi / 2, sphere_points)
        xs = sphere_radius * np.outer(np.cos(phi), np.sin(polar))
        ys = sphere_radius * np.outer(np.sin(phi), np.sin(polar))
        zs = sphere_radius * np.outer(np.ones(np.size(phi)), np.cos(polar)) + 1
        ax.plot_surface(xs, ys, zs, alpha=sphere_alpha)
        ax.scatter(plot_points[:, 0], plot_points[:, 1], plot_points[:, 2],
                   c=plot_colors, s=point_sizes)
        for (x, y, z), label in zip(plot_points, plot_indices):
            ax.text(x, y, z, label)
    else:
        ax2 = plt.axes([0.725, 0.725, 0.25, 0.25])
        ax2.tick_params(axis="both", labelbottom=False, labelleft=False)
        ax2.scatter(plot_points[:, 0], plot_points[:, 1], c=plot_colors)
        if len(plot_points):
            # Offset labels by 5% of the x/y data range so they sit beside
            # their points
            ranges = (np.amax(plot_points, axis=0)
                      - np.amin(plot_points, axis=0))[:-1]
            for point, label in zip(plot_points, plot_indices):
                x, y = point[0:2] - 0.05 * ranges
                ax2.text(x, y, label, va='top', ha='right')


def _drop_scattered_waves(k_out, intensities, indices, delete_negative=True,
                          delete_zero=True):
    """Remove waves with k_z < 0 and/or k_z == 0, per the two flags."""
    if not (delete_negative or delete_zero):
        return k_out, intensities, indices
    k_out = np.asarray(k_out)
    k_z = k_out[:, 2]
    drop = np.zeros(k_z.shape, dtype=bool)
    if delete_negative:
        drop |= k_z < 0
    if delete_zero:
        drop |= k_z == 0
    keep = ~drop
    return (k_out[keep], np.asarray(intensities)[keep],
            np.asarray(indices)[keep])


def _intensity_colors(intensities):
    """Return (N, 4) black RGBA rows with alpha = intensity / max intensity."""
    intensities = np.asarray(intensities, dtype=float)
    colors = np.zeros((intensities.size, 4))
    if intensities.size:
        peak = np.amax(intensities)
        # All-zero intensities stay fully transparent instead of 0/0.
        if peak > 0:
            colors[:, 3] = intensities / peak
    return colors


def prune_scattered_points(points, intensities, indices):
    """
    Merge scattered points that land on the same position.

    Parameters
    ----------
    points : array_like, shape (N, D)
        Projected position of each scattered wave.
    intensities : array_like, shape (N,)
        Intensity of each scattered wave.
    indices : array_like, shape (N, K)
        Index row for each wave, used as its text label.

    Returns
    -------
    unique_points : ndarray, shape (M, D)
        The distinct rows of ``points``, in ``np.unique`` (lexicographic)
        order.
    new_intensities : ndarray, shape (M,)
        For each unique point, the sum of the intensities of all the points
        merged into it.
    new_indices : list of str, length M
        For each unique point, the index tuples of the merged points, one
        tuple per line.
    """
    points = np.asarray(points)
    intensities = np.asarray(intensities)
    indices = np.asarray(indices)
    if len(points) == 0:
        return points, np.zeros(0), []

    unique_points, inverse = np.unique(points, axis=0, return_inverse=True)
    # Some NumPy versions return the inverse with an extra axis; flatten it
    # so each entry is a plain integer id into unique_points.
    inverse = np.asarray(inverse).reshape(-1)

    new_intensities = np.zeros(len(unique_points))
    # Unbuffered add, so every point sharing an id contributes its intensity.
    np.add.at(new_intensities, inverse, intensities)

    groups = [[] for _ in range(len(unique_points))]
    # tolist() yields plain Python numbers, so labels read "(1, 0, 0)" rather
    # than NumPy 2's "(np.int64(1), ...)" scalar reprs.
    for uid, row in zip(inverse, indices.tolist()):
        groups[uid].append(str(tuple(row)))
    new_indices = ["\n".join(group) for group in groups]

    return unique_points, new_intensities, new_indices


def projection_sphere(k_array, r=5, o=None):
    """
    Project vectors onto a sphere of radius ``r`` centred on ``o``.

    Each vector k is mapped to ``o + k * (r / |k|)``: the point where the ray
    starting at ``o`` in the direction of k crosses the sphere.

    Parameters
    ----------
    k_array : array_like, shape (N, D) or (D,)
        Vectors to project. Each must be non-zero.
    r : float
        Sphere radius.
    o : array_like, shape (D,), optional
        Sphere centre and ray origin. Defaults to ``(0, 0, 1)``.

    Returns
    -------
    ndarray with the same shape as ``k_array``.
    """
    if o is None:
        o = np.array([0, 0, 1])
    k_array = np.asarray(k_array)
    # keepdims makes the per-row scale broadcast over any vector dimension.
    lengths = np.linalg.norm(k_array, axis=-1, keepdims=True)
    return o + k_array * (r / lengths)
    

#setup_scattering(k_in=np.array([0,0,-3*np.pi]))