# Computational Astrophysics 2021
---
## Eduard Larrañaga

Observatorio Astronómico Nacional\
Facultad de Ciencias\
Universidad Nacional de Colombia

---

## 03. Cross-validation of Decision Trees
### About this notebook

In this worksheet we introduce cross-validation of decision trees.

---

So far, we have validated the decision tree model using the median of the differences between the predicted and the target values. This method, in which we split the data into two sets (train and test), is known as **hold-out validation** and is the most basic form of validation. Its drawback is that the measured accuracy (the median of the differences) depends on how the data happen to be split into the two subsets.


Now we introduce a better validation method, called **k-fold cross-validation**. It is similar to hold-out validation, except that we split the data into k subsets and train and test the model k times, using a different combination of the subsets each time and recording the accuracy of each run (i.e. we perform hold-out validation k times).

In practice, in each iteration we train the model on k-1 of the subsets and test it on the remaining subset, so that every subset is used exactly once for testing. Finally, we take the average of the k accuracy measurements as the overall accuracy of the model.
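The procedure above can be sketched with NumPy alone. This is a minimal illustration, not the implementation used later in the worksheet (which relies on scikit-learn's `KFold`); the helper names `kfold_indices` and `kfold_score` are hypothetical, and the toy "model" simply predicts the mean of the training targets.

```python
import numpy as np

def kfold_indices(n_samples, k, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross-validation.

    The indices are shuffled once and cut into k folds; each fold is the
    test set exactly once, and the other k-1 folds form the training set.
    """
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), k)
    for i in range(k):
        test_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, test_idx

def kfold_score(x, y, k, fit_predict, score):
    """Return (mean of the k per-fold scores, list of per-fold scores)."""
    scores = []
    for train_idx, test_idx in kfold_indices(len(y), k):
        pred = fit_predict(x[train_idx], y[train_idx], x[test_idx])
        scores.append(score(pred, y[test_idx]))
    return np.mean(scores), scores

# Toy usage: 10 samples, a "model" that predicts the training mean,
# scored with the median of the absolute differences.
x = np.arange(10, dtype=float).reshape(-1, 1)
y = 2.0 * x[:, 0]
mean_model = lambda x_tr, y_tr, x_te: np.full(len(x_te), y_tr.mean())
median_abs = lambda pred, true: np.median(np.abs(pred - true))
overall, per_fold = kfold_score(x, y, k=5, fit_predict=mean_model, score=median_abs)
print(overall, per_fold)
```

Each of the k scores comes from data the model did not see during training, and `overall` is their average.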

In this worksheet we will use the same dataset of galaxies as in previous worksheets. However, we will also estimate how accurate the model is for Quasi-Stellar Objects (QSOs) compared with other galaxies. QSOs are galaxies with an Active Galactic Nucleus (AGN), which makes them much brighter; therefore, they are detectable with the SDSS instruments at much higher redshifts.

### Loading the Data

As before, we will use the dataset provided as a NumPy structured array in binary format (.npy), called 'sdss_galaxy_colors.npy'.


In [None]:
import numpy as np

In [None]:
path='' #Define an empty string to use in case of local working

In [None]:
# Working with google colab needs to mount the Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# we define the path to the files
path = '/content/drive/MyDrive/Colab Notebooks/CA2021/11. Decision Trees/presentation/'

In [None]:
data = np.load(path+'sdss_galaxy_colors.npy')
data

array([(19.84132, 19.52656, 19.46946, 19.17955, 19.10763, b'QSO', 0.539301  , 6.543622e-05),
       (19.86318, 18.66298, 17.84272, 17.38978, 17.14313, b'GALAXY', 0.1645703 , 1.186625e-05),
       (19.97362, 18.31421, 17.47922, 17.0744 , 16.76174, b'GALAXY', 0.04190006, 2.183788e-05),
       ...,
       (19.82667, 18.10038, 17.16133, 16.5796 , 16.19755, b'GALAXY', 0.0784592 , 2.159406e-05),
       (19.98672, 19.75385, 19.5713 , 19.27739, 19.25895, b'QSO', 1.567295  , 4.505933e-04),
       (18.00024, 17.80957, 17.77302, 17.72663, 17.7264 , b'QSO', 0.4749449 , 6.203324e-05)],
      dtype=[('u', '<f8'), ('g', '<f8'), ('r', '<f8'), ('i', '<f8'), ('z', '<f8'), ('spec_class', 'S6'), ('redshift', '<f8'), ('redshift_err', '<f8')])

In this kind of data structure, the `dtype` attribute lists the names (and data types) of the fields, which correspond to the features. For our example, we identify the following:

| Field | Description |
|:-:|:-:|
|`u` |magnitude in the u band|
|`g` |magnitude in the g band|
|`r` |magnitude in the r band|
|`i` |magnitude in the i band|
|`z` |magnitude in the z band|
|`spec_class` |spectral class (`GALAXY` or `QSO`)|
|`redshift` |redshift|
|`redshift_err` |redshift error|
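To see how the fields of a structured array are accessed, here is a small sketch that rebuilds two records (rounded from the first rows printed above) with the same `dtype`; the name `sample` is only for illustration.

```python
import numpy as np

# Same field layout as sdss_galaxy_colors.npy; two rounded example records.
dtype = [('u', '<f8'), ('g', '<f8'), ('r', '<f8'), ('i', '<f8'), ('z', '<f8'),
         ('spec_class', 'S6'), ('redshift', '<f8'), ('redshift_err', '<f8')]
sample = np.array([
    (19.84, 19.53, 19.47, 19.18, 19.11, b'QSO', 0.539, 6.5e-05),
    (19.86, 18.66, 17.84, 17.39, 17.14, b'GALAXY', 0.165, 1.2e-05),
], dtype=dtype)

print(sample.dtype.names)  # the field names, in order
print(sample['u'])         # one field as an ordinary float array
print(sample[0])           # one complete record (row)
```

Indexing by field name returns a column over all records, while integer indexing returns a single record.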


The number of samples (galaxies and QSOs) in this dataset is

In [None]:
n = data.size
n

50000

To implement the decision tree, we will use the functions defined in the previous worksheet to obtain the features (the 4 color indices) and the targets (the redshifts).

In [None]:
# Function returning the 4 color indices and the redshifts

features, targets = ...
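For reference, one possible version of that function is sketched below. It assumes, as is common for this SDSS dataset, that the four color indices are u-g, g-r, r-i and i-z; the name `get_features_targets` is hypothetical and may differ from the one used in the previous worksheet.

```python
import numpy as np

def get_features_targets(data):
    """Return (features, targets) from an SDSS-like structured array.

    Assumption: features are the colour indices u-g, g-r, r-i, i-z,
    and the target is the 'redshift' field.
    """
    features = np.zeros((data.shape[0], 4))
    features[:, 0] = data['u'] - data['g']
    features[:, 1] = data['g'] - data['r']
    features[:, 2] = data['r'] - data['i']
    features[:, 3] = data['i'] - data['z']
    targets = data['redshift']
    return features, targets

# Usage on a one-record toy array with the relevant fields.
demo = np.array([(19.84, 19.53, 19.47, 19.18, 19.11, 0.54)],
                dtype=[('u', '<f8'), ('g', '<f8'), ('r', '<f8'),
                       ('i', '<f8'), ('z', '<f8'), ('redshift', '<f8')])
f, t = get_features_targets(demo)
print(f, t)
```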


### K-Fold

We will use `sklearn.model_selection.KFold`, a class designed to split the data into training and testing subsets. It provides an object whose `.split()` method iterates over the train/test index pairs, and it can be initialised with

```
kf = KFold(n_splits=k, shuffle=True)
```

where the argument `n_splits=k` specifies the number of subsets to use. The argument `shuffle` is `False` by default, but it is generally good practice to shuffle the data for cross-validation: during collection and storage, data of a similar type are often stored adjacently, and without shuffling some subsets would contain mostly one kind of object, which would bias the training of the tree. Complete information about this class is available at

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
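The effect of `shuffle` is easy to see on a toy array of 10 samples. This sketch assumes scikit-learn is installed; `random_state` is an optional `KFold` argument (not used in this worksheet) that makes the shuffled split reproducible.

```python
import numpy as np
from sklearn.model_selection import KFold

x = np.arange(10).reshape(-1, 1)  # 10 toy samples stored "in order"

# Without shuffling, each test fold is a contiguous block of the input order:
# prints [0 1], then [2 3], ..., then [8 9]
for _, test_idx in KFold(n_splits=5).split(x):
    print(test_idx)

# With shuffle=True the test folds mix samples from the whole array;
# random_state fixes the permutation so the split can be repeated exactly.
kf_shuffled = KFold(n_splits=5, shuffle=True, random_state=42)
for _, test_idx in kf_shuffled.split(x):
    print(test_idx)
```

With adjacent storage of similar objects, the unshuffled folds would each contain one "block" of the data, which is the bias described above.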

Now we will use `KFold` to split the dataset of galaxies into k−1 training subsets and one remaining test subset. The first step is to initialise it with, for example, **k=5**. We will also initialise the decision tree with `max_depth=19` (based on our previous results on over-fitting).

In [None]:
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import KFold

kf = KFold(n_splits=5, shuffle=True)
dec_tree = DecisionTreeRegressor(max_depth=19)


Applying the `.split()` method to the feature set generates the **train and test indices**. Note that this defines only the indices of each subset, not the subsets themselves, so the subsets must be built from those indices. Once the subsets are defined, they are used to train the decision tree and to evaluate its predictions using the median of the differences.

The whole process of defining the subsets, training the tree and evaluating the predictions must be repeated for each of the k iterations. Hence, we implement these steps with a `for` loop as follows:



In [None]:
# declare an array for predicted redshifts from each iteration
all_predictions = np.zeros_like(targets)

for train_indices, test_indices in kf.split(features):
  train_features, test_features = features[train_indices], features[test_indices]
  train_targets, test_targets = targets[train_indices], targets[test_indices]

  # Train the decision tree
  dec_tree.fit(train_features, train_targets)
  
  # Predict using the model
  predictions = dec_tree.predict(test_features)

  # put the predicted values in the all_predictions array defined above
  all_predictions[test_indices] = predictions



# Evaluate the model using the median of differences of all_predictions
eval_dec_tree = ...  
eval_dec_tree
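The loop above pools all out-of-fold predictions and computes a single median, whereas the description of k-fold cross-validation averages the k per-fold accuracies. The sketch below computes both on synthetic data. It assumes that "median of the differences" means the median of the absolute differences (the helper `median_diff` is hypothetical), and the random features only stand in for the real color indices.

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeRegressor

def median_diff(predicted, actual):
    """Median of the absolute prediction errors (assumed definition)."""
    return np.median(np.abs(predicted - actual))

# Synthetic stand-in for (features, targets): 4 "colours", a smooth target.
rng = np.random.default_rng(0)
features = rng.uniform(-1.0, 2.0, size=(2000, 4))
targets = (0.3 * features[:, 0] + 0.1 * features[:, 1] ** 2
           + rng.normal(0.0, 0.02, 2000))

kf = KFold(n_splits=5, shuffle=True, random_state=0)
dec_tree = DecisionTreeRegressor(max_depth=19)

all_predictions = np.zeros_like(targets)
fold_scores = []
for train_idx, test_idx in kf.split(features):
    dec_tree.fit(features[train_idx], targets[train_idx])
    predictions = dec_tree.predict(features[test_idx])
    all_predictions[test_idx] = predictions
    fold_scores.append(median_diff(predictions, targets[test_idx]))

pooled = median_diff(all_predictions, targets)  # one median over all objects
averaged = np.mean(fold_scores)                 # mean of the k per-fold medians
print(pooled, averaged, fold_scores)
```

The two summary numbers are usually close; the list of per-fold scores also shows how much the accuracy changes from one split to another.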

Once the model is trained and the predictions are calculated, we compare these predictions graphically with the targets.


In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

# plot the results to see how well our model looks
plt.figure()
plt.scatter(targets, all_predictions, s=0.4)
plt.xlim((0, targets.max()))
plt.ylim((0, all_predictions.max()))  # all folds, not only the last one
plt.xlabel('Measured Redshift')
plt.ylabel('Predicted Redshift')
plt.show()


The above plot should look like the following

<center>
<img src="https://groklearning-cdn.com/modules/SjroKib6Hs5Fqxq53Vxme9/predicted_v_measured.png" width=450>
</center>


Note that in the plot of predicted vs. measured redshifts many objects lie close to the one-to-one line, but there are also many outliers (points far from the line).

---
### Spectral Class
The 'spec_class' feature in the dataset takes two values, b'GALAXY' and b'QSO', which identify galaxies and Quasi-Stellar Objects (QSOs), respectively.
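Because `spec_class` is stored as a byte string (dtype `'S6'`), a class can be selected with a boolean mask that compares against the bytes literal `b'QSO'` or `b'GALAXY'`. A minimal sketch on made-up records (redshifts rounded from rows shown earlier); the name `sample` is only illustrative:

```python
import numpy as np

# Toy records with the same 'spec_class' and 'redshift' fields as the dataset.
sample = np.array([(b'QSO', 0.54), (b'GALAXY', 0.16), (b'GALAXY', 0.04), (b'QSO', 1.57)],
                  dtype=[('spec_class', 'S6'), ('redshift', '<f8')])

# The field holds bytes, so compare with b'QSO' / b'GALAXY'.
is_qso = sample['spec_class'] == b'QSO'
is_galaxy = sample['spec_class'] == b'GALAXY'

print(is_qso.sum(), is_galaxy.sum())  # number of objects in each class
print(sample['redshift'][is_qso])     # redshifts of the QSOs only
```

The same masks can index any array aligned with the data, such as `targets` or `all_predictions`.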

**1. Define a function that classifies the samples according to the 'spec_class'.**

**2. How many galaxies and how many QSOs are there in the dataset?**


**3. Calculate the median of the differences for galaxies and QSOs. What are the maximum values of these differences for each class of objects?**


Galaxies are not as bright as QSOs, so they become too faint to be detected with SDSS at redshifts greater than 0.4. This creates a measurement bias: the objects observed at high redshift are almost all QSOs.

**4. Make a plot of the median of differences vs. measured redshift for all the objects in the dataset, using one color for galaxies and another color for QSOs.**


Now plot the predicted vs. measured redshifts again, using one color for galaxies and another color for QSOs. The result should look like the following plot

<center>
<img src="https://groklearning-cdn.com/modules/ovFSymwFkqBPAcjnbSUxLG/predicted_actual_qso.png" width=450>
</center>

**5. Reproduce the above plot**