In [1]:
%matplotlib widget
import caesar
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle as pkl
import graphviz 

  self[key]


In [2]:
from sklearn import tree
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV

In [3]:
obj = caesar.load('/Users/ondrea/Downloads/caesar_296.hdf5') # 296=redshift0

yt : [INFO     ] 2021-09-06 15:05:11,447 Opening /Users/ondrea/Downloads/caesar_296.hdf5
yt : [INFO     ] 2021-09-06 15:05:11,668 Found 54135 halos
yt : [INFO     ] 2021-09-06 15:05:11,712 Found 6480 galaxies


In this notebook I use a decision tree classifier to determine the most important causes of quenching in massive galaxies in a single cosmological hydro snapshot at redshift zero.

### https://caesar.readthedocs.io/en/latest/usage.html for info on the data structure used

In [4]:
#load in data
orig_index = np.arange(0,len(obj.galaxies))
sfr = np.array([gal.sfr for gal in obj.galaxies])
stellar_masses = np.array([gal.masses['stellar'] for gal in obj.galaxies])
bh_masses = np.array([gal.masses['bh'] for gal in obj.galaxies])
gas_masses = np.array([gal.masses['gas'] for gal in obj.galaxies])
metallicity = np.array([gal.metallicities['stellar'] for gal in obj.galaxies])                    
temp = obj.galaxies[1].temperatures['mass_weighted']
radius = np.array([gal.radii['stellar_r20'] for gal in obj.galaxies])
ssfr = (sfr/stellar_masses)*1e9 

In [5]:
bh_masses[bh_masses == 0] = 1e-99
stellar_masses[stellar_masses == 0] = 1e-99
gas_masses[gas_masses == 0] = 1e-99
metallicity[metallicity == 0] = 1e-99
sfr[sfr == 0] = 1e-99
ssfr[ssfr == 0] = 1e-99

In [6]:
gal_data = pd.DataFrame({ 'bh_masses': np.log10(bh_masses),'gas_masses' : np.log10(gas_masses),\
                         'metallicity' : np.log10(metallicity), 'radius' : radius})

I have to load in my SVM classified points (See NB Dead_Classification_z=0.ipynb) which will tell me which galaxies are quenched and which are not. I will use this to create a series for my gal_data df indication quenched (1) or not quenched (0). I must do this now as the order will match at this point.

In [7]:
SVMclf = pkl.load(open( './poly_clf.pkl', 'rb' ))

In [8]:
#Check that scales are reasonable etc
#pd.plotting.scatter_matrix(gal_data, alpha=0.2, figsize=(10, 10), diagonal="kde");

In [9]:
arr=np.column_stack((np.log10(stellar_masses),np.log10(ssfr)))

In [10]:
SVMpredictions=SVMclf.predict(arr)

In [11]:
ifig=2;plt.close(ifig);plt.figure(ifig)
plt.scatter(np.log10(stellar_masses),np.log10(ssfr), s=1, c=SVMpredictions, cmap=sns.color_palette('Spectral', as_cmap=True))
plt.xlim(9.5,12);
plt.ylim(-4,1);
plt.xlabel('log stellar mass')
plt.ylabel('log ssfr')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, 0.5, 'log ssfr')

Looks great. Also note that zero values are counted as dead. Now I will add them to the df.

In [12]:
gal_data['quenched'] = SVMpredictions

In [13]:
gal_data

Unnamed: 0,bh_masses,gas_masses,metallicity,radius,quenched
0,10.290369,-99.0,-1.709673,12.585268,1
1,9.681071,-99.0,-1.686568,7.108331,1
2,9.890138,-99.0,-1.713443,8.093110,1
3,9.811176,-99.0,-1.702289,7.462501,1
4,9.830744,-99.0,-1.691337,7.025398,1
...,...,...,...,...,...
6475,-99.000000,-99.0,-2.948695,2.281374,1
6476,-99.000000,-99.0,-2.864811,2.289696,1
6477,-99.000000,-99.0,-3.026400,1.733394,1
6478,-99.000000,-99.0,-2.740298,2.117305,1


In [14]:
Xfor_tree=np.column_stack((gal_data['bh_masses'].to_numpy(), gal_data['gas_masses'].to_numpy(),\
                          gal_data['metallicity'].to_numpy(), gal_data['radius'].to_numpy(),))

In [15]:
idx = gal_data.index.to_numpy()

In [16]:
train_test_split?

In [17]:
#Split training and test split
X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(Xfor_tree, \
                                gal_data['quenched'], idx, test_size=0.25, random_state=1)

In [18]:
tree_classifier = tree.DecisionTreeClassifier()

In [19]:
#tree_classifier?

In [20]:
param_grid = [{'max_depth': [5,10,20,30], 'min_samples_leaf':[200,500,1000],\
              'max_leaf_nodes': [5,10,20,30]}]
grid_search = GridSearchCV(tree_classifier, param_grid, \
                        scoring = 'roc_auc', return_train_score=True)


In [21]:
grid_search.fit(X_train, y_train)
grid_search.best_params_

{'max_depth': 10, 'max_leaf_nodes': 10, 'min_samples_leaf': 200}

In [22]:
clf = tree.DecisionTreeClassifier(max_depth=10, max_leaf_nodes=30, min_samples_leaf=200).fit(X_train, y_train)

In [23]:
featue_importances = grid_search.best_estimator_.feature_importances_
print(featue_importances, gal_data.columns[:4])

[0.02794797 0.34786911 0.53144338 0.09273955] Index(['bh_masses', 'gas_masses', 'metallicity', 'radius'], dtype='object')


In [24]:
dot_data = tree.export_graphviz(clf, feature_names=gal_data.columns[:4]) 
graph = graphviz.Source(dot_data) 
graph.render("tree", format='png') 

'tree.png'

In [25]:
tree_predictions=clf.predict(X_test)

In [26]:
#Plot the decision trees predictions

In [29]:
ifig=3;plt.close(ifig);plt.figure(ifig)
plt.scatter(np.log10(stellar_masses[idx_test]),np.log10(ssfr[idx_test]), s=5, c=tree_predictions, cmap=sns.color_palette('crest', as_cmap=True))
plt.xlim(9.5,12);
plt.ylim(-4,1);
plt.xlabel('log stellar mass')
plt.ylabel('log ssfr')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, 0.5, 'log ssfr')

In [31]:
ifig=4;plt.close(ifig);plt.figure(ifig)
plt.scatter(gal_data['metallicity'],np.log10(ssfr), s=1, c=SVMpredictions, cmap=sns.color_palette('crest', as_cmap=True),zorder=1)
plt.xlabel('log Z')
plt.ylabel('log ssfr')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Text(0, 0.5, 'log ssfr')