In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

To revise what we have learnt so far, let's look at the following example.

# Diffusion using random walks
The following example has been adapted from [scipy-lectures.org](https://scipy-lectures.org/intro/numpy/operations.html#basic-reductions).

We will model the diffusion of a particle on a one-dimensional grid using a random walk. The particle starts at the origin at $t=0$ and at each time step jumps right or left with equal probability. A step to the left is a displacement of `-1` units and a step to the right is a displacement of `+1` units.

<img src="data/img/random_walk_1.png" height=100 width=450>

**We want to find the typical distance (in units of grid points) from the origin of a random walker after `t` left or right jumps.**
To achieve this, we will generate a random trajectory for a walker. We will then generate many such walks (let's call them *stories*) and examine their statistical properties to find a pattern.

The simulation will be done using NumPy array computing tricks: we are going to create a 2D array with the *stories* along one axis and time along another.

<img src="data/img/random_walk_schema_1.png" height=300px width=300px>

In [None]:
n_stories = 10000 # number of stories, i.e. the number of independent walks
t_max = 200      # time during which we follow the walker

We will create the array of steps taken by the walkers shown in the above schema using the function `np.random.choice()`. The first argument will be a list of values from which the numbers will be chosen i.e. `[-1,1]`. The second argument will be a tuple denoting the shape of the array to be created.
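
As a small illustration that is separate from the exercise, the sketch below draws from a two-element list of coin faces and requests a 2D result. It uses the newer `np.random.default_rng()` generator interface, whose `choice()` method takes the same first two arguments as `np.random.choice()`. The seed `42` and the names `rng` and `flips` are arbitrary choices made for this example.

```python
import numpy as np

# A seeded generator makes the draw reproducible (the seed value is arbitrary)
rng = np.random.default_rng(42)

# First argument: the values to draw from; second argument: the output shape
flips = rng.choice(["H", "T"], size=(2, 6))

print(flips.shape)   # the shape matches the requested tuple
print(flips)
```

Each entry is drawn independently and, by default, every value in the list is equally likely.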

In [None]:
steps = np.random.choice()      # COMPLETE THIS LINE OF CODE

We find the *displacement* of each walker from the origin as a function of time by calculating the cumulative sum of the steps **along the time axis** using `np.cumsum()`.

<img src="data/img/random_walk_schema_2.png" height=300px width=300px>
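
To see what the `axis` argument of `np.cumsum()` does, here is a toy sketch on a small 2×3 array (the array `m` and the variable names are made up for this illustration and are unrelated to the `steps` array).

```python
import numpy as np

m = np.array([[1, 2, 3],
              [4, 5, 6]])

# Accumulate down the rows (axis 0): each column is summed independently
down_rows = np.cumsum(m, axis=0)     # [[1, 2, 3], [5, 7, 9]]

# Accumulate along the columns (axis 1): each row is summed independently
along_cols = np.cumsum(m, axis=1)    # [[1, 3, 6], [4, 9, 15]]

# Without an axis, the array is flattened first
flat = np.cumsum(m)                  # [1, 3, 6, 10, 15, 21]

print(down_rows)
print(along_cols)
print(flat)
```

With an `axis` given, the result has the same shape as the input. Deciding which axis of `steps` represents time is left to the exercise.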

In [None]:
displacements = np.cumsum(steps, )     # COMPLETE THIS LINE OF CODE

We now find the root mean squared (RMS) displacement as a function of time by calculating the statistic along the axis of the *stories*. You can use the functions `np.sqrt()` and `np.mean()` and the `**` operator.

In [None]:
# COMPLETE THESE THREE LINES OF CODE
sq_displacement =                      #squared displacement
mean_sq_disp = np.mean()                            #mean squared displacement along the story axis
rms_disp =                                 # root mean squared displacement

Let's now plot our results. We generate an array containing the time steps and plot the RMS displacement against it. We also plot $\sqrt{t}$ on the same axes for comparison.

In [None]:
# generate the time axis
t =             # COMPLETE THIS LINE OF CODE

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12,8))
plt.scatter(t, rms_disp, label = "Simulation", marker="x", c="C0")
plt.plot(t, np.sqrt(t), label = r"$\sqrt{t}$", c="C1", lw=2)
plt.legend()
plt.xlabel(r"$t$", fontsize=20) 
plt.ylabel(r"$\sqrt{\langle (\delta x)^2 \rangle}$", fontsize=20) 

We find a well-known result in physics: the RMS distance grows as the square root of the time!   
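
The scaling follows from a short calculation, sketched here with notation introduced only for this note: $s_i = \pm 1$ is the $i$-th step and $x_t$ is the displacement after $t$ steps. Because the steps are independent and have zero mean, the cross terms average to zero, while $s_i^2 = 1$ for every step:

$$
x_t = \sum_{i=1}^{t} s_i, \qquad
\langle x_t^2 \rangle = \sum_{i=1}^{t} \langle s_i^2 \rangle + \sum_{i \neq j} \langle s_i s_j \rangle = t + 0 = t .
$$

Hence $\sqrt{\langle x_t^2 \rangle} = \sqrt{t}$, which is the reference curve drawn in the plot above.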

To get a feel for how efficiently we did all the above calculations on such a huge number of elements, let us time the code. Paste all the code used for the calculations (except the plotting) in the cell below:

In [None]:
%%timeit
# PASTE THE CODE HERE

For comparison we will do a very simple calculation on the same number of elements using native Python. I hope this helps you to appreciate vectorized calculations!

In [None]:
%%timeit

vals = [i for i in range(n_stories*t_max)]
new_vals = [i+1 for i in vals]

# Fancy Indexing

### Boolean arrays and logical operations

Besides `int` and `float` values, the elements of a NumPy array can also be boolean values, i.e. `True` or `False`. Such arrays are typically created by an element-wise comparison between two arrays.

In [None]:
a1 = np.array([1, 2, 3, 4])
b1 = np.array([4, 2, 2, 4])

In [None]:
a1 == b1

In [None]:
a1 >= b1

**NOTE:** If we want to check whether two arrays are identical to each other (same shape and same elements), we can use the function `np.array_equal()`.
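
The difference between an element-wise comparison and a whole-array comparison can be sketched as follows (the arrays `a` and `b` and the other names are chosen only for this illustration).

```python
import numpy as np

a = np.array([1, 2, 3, 4])
b = np.array([4, 2, 2, 4])

# Element-wise comparison gives one boolean per element
elementwise = (a == b)                    # [False, True, False, True]

# np.array_equal() gives a single boolean for the whole array
same = np.array_equal(a, b)               # False
same_copy = np.array_equal(a, a.copy())   # True

# .any() and .all() reduce a boolean array to a single value
any_match = elementwise.any()             # True
all_match = elementwise.all()             # False

print(elementwise, same, same_copy, any_match, all_match)
```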

Element-wise logical operations can be done using built-in NumPy functions:

In [None]:
a2 = np.array([True, True, False, False])
b2 = np.array([True, False, True, False])

In [None]:
np.logical_or(a2, b2)

In [None]:
a2 | b2

**NOTE:** In addition to the predefined functions shown above, the operators `&`, `|` and `~` can also be used to compute the element-wise logical AND, OR and NOT of boolean arrays.   
When performing `sum()` on boolean arrays, the `True` values are treated as 1 and the `False` values as 0.
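
A common use of these operators is to combine two comparisons and count the matches. The sketch below uses a made-up array `v`. The parentheses around each comparison are needed because `&` binds more tightly than `>` and `<`.

```python
import numpy as np

v = np.array([3, 8, 12, 5, 20, 7])

# Combine two conditions element-wise; each comparison is parenthesized
in_range = (v > 4) & (v < 10)    # [False, True, False, True, False, True]

# True counts as 1, so sum() counts the matches and mean() gives their fraction
n_in_range = in_range.sum()      # 3
frac_in_range = in_range.mean()  # 0.5

# ~ inverts the mask
outside = ~in_range              # [True, False, True, False, True, False]

print(in_range, n_in_range, frac_in_range, outside)
```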

In [None]:
a2.sum()

### Indexing with boolean arrays

If, instead of integers, we index an array with a boolean array of the same (or a compatible) shape, the result contains the elements of the original array for which the corresponding boolean index is `True`. For example:

In [None]:
a1

In [None]:
a1[[True, False, False, True]]

We may have an array of data where negative values indicate some kind of error. We can use a boolean *mask* to select the array elements that satisfy our criterion:

In [None]:
x = np.array([1.2, 2.8, 3.5, -999, 2.7, 4.8, -999])

mask = (x > 0)
mask

In [None]:
x[mask]

Generally it is done in a single step

In [None]:
x[x>0]

We can also set specific values for array elements which satisfy our criteria

In [None]:
x[x<0] = np.nan
x

**NOTE:** `np.nan` is a special data object (of type `float`) which is used to denote invalid or missing values. NumPy is built to gracefully handle invalid or missing data points as long as they are marked with `NaN` (Not a Number). This is the recommended way of doing this instead of the more traditional way of denoting missing data with absurd numbers. For convenience, NumPy has a host of such special constants defined which are listed [here](https://numpy.org/doc/stable/reference/constants.html?highlight=constants).
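
A short sketch of how `NaN` values behave in practice, using a made-up array `y`: since `NaN` does not compare equal to anything (including itself), `np.isnan()` is used to locate it, and the `nan`-aware reductions such as `np.nanmean()` skip it.

```python
import numpy as np

y = np.array([1.0, np.nan, 3.0, np.nan, 5.0])

# NaN never compares equal, not even to itself
print(np.nan == np.nan)        # False

# np.isnan() builds a boolean mask of the missing entries
missing = np.isnan(y)          # [False, True, False, True, False]

# An ordinary reduction propagates NaN; the nan-aware version ignores it
plain_mean = np.mean(y)        # nan
nan_mean = np.nanmean(y)       # 3.0

# The inverted mask selects the valid entries
valid = y[~missing]            # [1.0, 3.0, 5.0]

print(missing, plain_mean, nan_mean, valid)
```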

### Masking lines in a spectrum

An `.npy` file (`data/sdss_spectra.npy`) has been provided containing the spectrum of a galaxy. The `0th` axis corresponds to the wavelength sampling. The `1st` axis has 3 elements: the `0th` element is the wavelength grid, the `1st` element is the measured (and normalized) flux, and the `2nd` element gives the flux errors. Let's first read in the data and unpack it into separate arrays.

In [None]:
data_path = Path("./data/sdss_spectra.npy")
#COMPLETE THESE THREE LINES OF CODE
data =  #read in the data
wavelength = data #allocate the proper 1D slices
flux = data
flux_err = data

Let's plot the spectrum.

In [None]:
plt.figure(figsize=(16,8))
plt.plot(wavelength, flux, label="Flux")
plt.plot(wavelength, flux_err, label="Error")
plt.ylim(-1,2)
plt.xlabel(r"Wavelength [$\AA$]", fontsize=20)  # raw string so that \A is not read as an escape sequence
plt.ylabel("Normalized Flux [Arbitrary Units]")
plt.legend()

Let's mask the measurements that have large errors by rejecting all flux values for which the error is greater than 0.2. We will first create a boolean mask.

In [None]:
good_data_mask = () #COMPLETE THIS LINE OF CODE

Now let's use this mask to index our arrays.

In [None]:
#COMPLETE THESE TWO LINES OF CODE
good_wavelength = wavelength
good_flux = flux

Let's plot the arrays we selected.

In [None]:
plt.figure(figsize=(16,8))
plt.plot(good_wavelength, good_flux, label="Flux")
plt.plot(wavelength, flux_err, label="Error")
plt.ylim(-0.75,0.75)
plt.xlabel(r"Wavelength [$\AA$]", fontsize=20)  # raw string so that \A is not read as an escape sequence
plt.ylabel("Normalized Flux [Arbitrary Units]")
plt.legend()

We can see emission and absorption lines in the above spectrum. Let's try to get rid of them by keeping only the flux values that lie less than $2\sigma$ from the mean. For this you need the NumPy functions `np.mean()` and `np.std()` and the `&` boolean operator. (Alternatively, the `np.abs()` function can be used.)

In [None]:
#COMPLETE THIS LINE OF CODE
cont_mask = 

Let's see what we have selected using this mask.

In [None]:
cont_wavelength = good_wavelength[cont_mask]
cont_flux = good_flux[cont_mask]

plt.figure(figsize=(16,8))
plt.plot(good_wavelength, good_flux, label="Flux", ls="--", alpha=0.8)
plt.plot(cont_wavelength, cont_flux, label="Clipped Flux")

plt.ylim(-0.75,0.75)
plt.xlabel(r"Wavelength [$\AA$]", fontsize=20)  # raw string so that \A is not read as an escape sequence
plt.ylabel("Normalized Flux [Arbitrary Units]")
plt.legend()